mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
23 Commits
feat/copil
...
feat/githu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
88eaab2baa | ||
|
|
4b0a445635 | ||
|
|
36312d2c6e | ||
|
|
d6d3b8d710 | ||
|
|
17d8d0bf05 | ||
|
|
5a2ab65f41 | ||
|
|
81a318de3e | ||
|
|
62c8e8634b | ||
|
|
b91c959cd9 | ||
|
|
5b95a2a1ef | ||
|
|
9c2a601167 | ||
|
|
b98e37bf23 | ||
|
|
fec8924361 | ||
|
|
712aee7302 | ||
|
|
bef292033e | ||
|
|
ec6974e3b8 | ||
|
|
2ef5e2fe77 | ||
|
|
0a8c7221ce | ||
|
|
840d1de636 | ||
|
|
ac55ab619b | ||
|
|
a8014d1e92 | ||
|
|
7de13c7713 | ||
|
|
9358b525a0 |
162
autogpt_platform/backend/backend/copilot/integration_creds.py
Normal file
162
autogpt_platform/backend/backend/copilot/integration_creds.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""Integration credential lookup with per-process TTL cache.
|
||||
|
||||
Provides token retrieval for connected integrations so that copilot tools
|
||||
(e.g. bash_exec) can inject auth tokens into the execution environment without
|
||||
hitting the database on every command.
|
||||
|
||||
Cache semantics (handled automatically by TTLCache):
|
||||
- Token found → cached for _TOKEN_CACHE_TTL (5 min). Avoids repeated DB hits
|
||||
for users who have credentials and are running many bash commands.
|
||||
- No credentials found → cached for _NULL_CACHE_TTL (60 s). Avoids a DB hit
|
||||
on every E2B command for users who haven't connected an account yet, while
|
||||
still picking up a newly-connected account within one minute.
|
||||
|
||||
Both caches are bounded to _CACHE_MAX_SIZE entries; cachetools evicts the
|
||||
least-recently-used entry when the limit is reached.
|
||||
|
||||
Multi-worker note: both caches are in-process only. Each worker/replica
|
||||
maintains its own independent cache, so a credential fetch may be duplicated
|
||||
across processes. This is acceptable for the current goal (reduce DB hits per
|
||||
session per-process), but if cache efficiency across replicas becomes important
|
||||
a shared cache (e.g. Redis) should be used instead.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
register_creds_changed_hook,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maps provider slug → env var names to inject when the provider is connected.
|
||||
# Add new providers here when adding integration support.
|
||||
# NOTE: keep in sync with connect_integration._PROVIDER_INFO — both registries
|
||||
# must be updated when adding a new provider.
|
||||
PROVIDER_ENV_VARS: dict[str, list[str]] = {
|
||||
"github": ["GH_TOKEN", "GITHUB_TOKEN"],
|
||||
}
|
||||
|
||||
_TOKEN_CACHE_TTL = 300.0 # seconds — for found tokens
|
||||
_NULL_CACHE_TTL = 60.0 # seconds — for "not connected" results
|
||||
_CACHE_MAX_SIZE = 10_000
|
||||
|
||||
# (user_id, provider) → token string. TTLCache handles expiry + eviction.
|
||||
# Thread-safety note: TTLCache is NOT thread-safe, but that is acceptable here
|
||||
# because all callers (get_provider_token, invalidate_user_provider_cache) run
|
||||
# exclusively on the asyncio event loop. There are no await points between a
|
||||
# cache read and its corresponding write within any function, so no concurrent
|
||||
# coroutine can interleave. If ThreadPoolExecutor workers are ever added to
|
||||
# this path, a threading.RLock should be wrapped around these caches.
|
||||
_token_cache: TTLCache[tuple[str, str], str] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_TOKEN_CACHE_TTL
|
||||
)
|
||||
# Separate cache for "no credentials" results with a shorter TTL.
|
||||
_null_cache: TTLCache[tuple[str, str], bool] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
|
||||
)
|
||||
|
||||
|
||||
def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
|
||||
"""Remove the cached entry for *user_id*/*provider* from both caches.
|
||||
|
||||
Call this after storing new credentials so that the next
|
||||
``get_provider_token()`` call performs a fresh DB lookup instead of
|
||||
serving a stale TTL-cached result.
|
||||
"""
|
||||
key = (user_id, provider)
|
||||
_token_cache.pop(key, None)
|
||||
_null_cache.pop(key, None)
|
||||
|
||||
|
||||
# Register this module's cache-bust function with the credentials manager so
|
||||
# that any create/update/delete operation immediately evicts stale cache
|
||||
# entries. This avoids a lazy import inside creds_manager and eliminates the
|
||||
# circular-import risk.
|
||||
register_creds_changed_hook(invalidate_user_provider_cache)
|
||||
|
||||
# Module-level singleton to avoid re-instantiating IntegrationCredentialsManager
|
||||
# on every cache-miss call to get_provider_token().
|
||||
_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
async def get_provider_token(user_id: str, provider: str) -> str | None:
|
||||
"""Return the user's access token for *provider*, or ``None`` if not connected.
|
||||
|
||||
OAuth2 tokens are preferred (refreshed if needed); API keys are the fallback.
|
||||
Found tokens are cached for _TOKEN_CACHE_TTL (5 min). "Not connected" results
|
||||
are cached for _NULL_CACHE_TTL (60 s) to avoid a DB hit on every bash_exec
|
||||
command for users who haven't connected yet, while still picking up a
|
||||
newly-connected account within one minute.
|
||||
"""
|
||||
cache_key = (user_id, provider)
|
||||
|
||||
if cache_key in _null_cache:
|
||||
return None
|
||||
if cached := _token_cache.get(cache_key):
|
||||
return cached
|
||||
|
||||
manager = _manager
|
||||
try:
|
||||
creds_list = await manager.store.get_creds_by_provider(user_id, provider)
|
||||
except Exception:
|
||||
logger.debug("Failed to fetch %s credentials for user %s", provider, user_id)
|
||||
return None
|
||||
|
||||
# Pass 1: prefer OAuth2 (carry scope info, refreshable via token endpoint).
|
||||
# Sort so broader-scoped tokens come first: a token with "repo" scope covers
|
||||
# full git access, while a public-data-only token lacks push/pull permission.
|
||||
# lock=False — background injection; not worth a distributed lock acquisition.
|
||||
oauth2_creds = sorted(
|
||||
[c for c in creds_list if c.type == "oauth2"],
|
||||
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
|
||||
)
|
||||
for creds in oauth2_creds:
|
||||
if creds.type == "oauth2":
|
||||
try:
|
||||
fresh = await manager.refresh_if_needed(
|
||||
user_id, cast(OAuth2Credentials, creds), lock=False
|
||||
)
|
||||
token = fresh.access_token.get_secret_value()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to refresh %s OAuth token for user %s; "
|
||||
"falling back to potentially stale token",
|
||||
provider,
|
||||
user_id,
|
||||
)
|
||||
token = cast(OAuth2Credentials, creds).access_token.get_secret_value()
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# Pass 2: fall back to API key (no expiry, no refresh needed).
|
||||
for creds in creds_list:
|
||||
if creds.type == "api_key":
|
||||
token = cast(APIKeyCredentials, creds).api_key.get_secret_value()
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# No credentials found — cache to avoid repeated DB hits.
|
||||
_null_cache[cache_key] = True
|
||||
return None
|
||||
|
||||
|
||||
async def get_integration_env_vars(user_id: str) -> dict[str, str]:
|
||||
"""Return env vars for all providers the user has connected.
|
||||
|
||||
Iterates :data:`PROVIDER_ENV_VARS`, fetches each token, and builds a flat
|
||||
``{env_var: token}`` dict ready to pass to a subprocess or E2B sandbox.
|
||||
Only providers with a stored credential contribute entries.
|
||||
"""
|
||||
env: dict[str, str] = {}
|
||||
for provider, var_names in PROVIDER_ENV_VARS.items():
|
||||
token = await get_provider_token(user_id, provider)
|
||||
if token:
|
||||
for var in var_names:
|
||||
env[var] = token
|
||||
return env
|
||||
@@ -0,0 +1,193 @@
|
||||
"""Tests for integration_creds — TTL cache and token lookup paths."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.copilot.integration_creds import (
|
||||
_NULL_CACHE_TTL,
|
||||
_TOKEN_CACHE_TTL,
|
||||
PROVIDER_ENV_VARS,
|
||||
_null_cache,
|
||||
_token_cache,
|
||||
get_integration_env_vars,
|
||||
get_provider_token,
|
||||
invalidate_user_provider_cache,
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
|
||||
_USER = "user-integration-creds-test"
|
||||
_PROVIDER = "github"
|
||||
|
||||
|
||||
def _make_api_key_creds(key: str = "test-api-key") -> APIKeyCredentials:
|
||||
return APIKeyCredentials(
|
||||
id="creds-api-key",
|
||||
provider=_PROVIDER,
|
||||
api_key=SecretStr(key),
|
||||
title="Test API Key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_oauth2_creds(token: str = "test-oauth-token") -> OAuth2Credentials:
|
||||
return OAuth2Credentials(
|
||||
id="creds-oauth2",
|
||||
provider=_PROVIDER,
|
||||
title="Test OAuth",
|
||||
access_token=SecretStr(token),
|
||||
refresh_token=SecretStr("test-refresh"),
|
||||
access_token_expires_at=None,
|
||||
refresh_token_expires_at=None,
|
||||
scopes=[],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_caches():
|
||||
"""Ensure clean caches before and after every test."""
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
yield
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
|
||||
|
||||
class TestInvalidateUserProviderCache:
|
||||
def test_removes_token_entry(self):
|
||||
key = (_USER, _PROVIDER)
|
||||
_token_cache[key] = "tok"
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert key not in _token_cache
|
||||
|
||||
def test_removes_null_entry(self):
|
||||
key = (_USER, _PROVIDER)
|
||||
_null_cache[key] = True
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert key not in _null_cache
|
||||
|
||||
def test_noop_when_key_not_cached(self):
|
||||
# Should not raise even when there is no cache entry.
|
||||
invalidate_user_provider_cache("no-such-user", _PROVIDER)
|
||||
|
||||
def test_only_removes_targeted_key(self):
|
||||
other_key = ("other-user", _PROVIDER)
|
||||
_token_cache[other_key] = "other-tok"
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert other_key in _token_cache
|
||||
|
||||
|
||||
class TestGetProviderToken:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_returns_cached_token_without_db_hit(self):
|
||||
_token_cache[(_USER, _PROVIDER)] = "cached-tok"
|
||||
|
||||
mock_manager = MagicMock()
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "cached-tok"
|
||||
mock_manager.store.get_creds_by_provider.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_returns_none_for_null_cached_provider(self):
|
||||
_null_cache[(_USER, _PROVIDER)] = True
|
||||
|
||||
mock_manager = MagicMock()
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
mock_manager.store.get_creds_by_provider.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_api_key_creds_returned_and_cached(self):
|
||||
api_creds = _make_api_key_creds("my-api-key")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[api_creds])
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "my-api-key"
|
||||
assert _token_cache.get((_USER, _PROVIDER)) == "my-api-key"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_preferred_over_api_key(self):
|
||||
oauth_creds = _make_oauth2_creds("oauth-tok")
|
||||
api_creds = _make_api_key_creds("api-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(
|
||||
return_value=[api_creds, oauth_creds]
|
||||
)
|
||||
mock_manager.refresh_if_needed = AsyncMock(return_value=oauth_creds)
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "oauth-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_refresh_failure_falls_back_to_stale_token(self):
|
||||
oauth_creds = _make_oauth2_creds("stale-oauth-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
|
||||
mock_manager.refresh_if_needed = AsyncMock(side_effect=RuntimeError("network"))
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "stale-oauth-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_credentials_caches_null_entry(self):
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
assert _null_cache.get((_USER, _PROVIDER)) is True
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_db_exception_returns_none_without_caching(self):
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(
|
||||
side_effect=RuntimeError("db down")
|
||||
)
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
# DB errors are not cached — next call will retry
|
||||
assert (_USER, _PROVIDER) not in _token_cache
|
||||
assert (_USER, _PROVIDER) not in _null_cache
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_null_cache_has_shorter_ttl_than_token_cache(self):
|
||||
"""Verify the TTL constants are set correctly for each cache."""
|
||||
assert _null_cache.ttl == _NULL_CACHE_TTL
|
||||
assert _token_cache.ttl == _TOKEN_CACHE_TTL
|
||||
assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL
|
||||
|
||||
|
||||
class TestGetIntegrationEnvVars:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_injects_all_env_vars_for_provider(self):
|
||||
_token_cache[(_USER, "github")] = "gh-tok"
|
||||
|
||||
result = await get_integration_env_vars(_USER)
|
||||
|
||||
for var in PROVIDER_ENV_VARS["github"]:
|
||||
assert result[var] == "gh-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_empty_dict_when_no_credentials(self):
|
||||
_null_cache[(_USER, "github")] = True
|
||||
|
||||
result = await get_integration_env_vars(_USER)
|
||||
|
||||
assert result == {}
|
||||
@@ -11,18 +11,34 @@ from backend.copilot.tools import TOOL_REGISTRY
|
||||
# Shared technical notes that apply to both SDK and baseline modes
|
||||
_SHARED_TOOL_NOTES = """\
|
||||
|
||||
### Sharing files
|
||||
After `write_workspace_file`, embed the `download_url` in Markdown:
|
||||
- File: `[report.csv](workspace://file_id#text/csv)`
|
||||
- Image: ``
|
||||
- Video: ``
|
||||
### Sharing files with the user
|
||||
After saving a file to the persistent workspace with `write_workspace_file`,
|
||||
share it with the user by embedding the `download_url` from the response in
|
||||
your message as a Markdown link or image:
|
||||
|
||||
### File references — @@agptfile:
|
||||
Pass large file content to tools by reference: `@@agptfile:<uri>[<start>-<end>]`
|
||||
- `workspace://<file_id>` or `workspace:///<path>` — workspace files
|
||||
- `/absolute/path` — local/sandbox files
|
||||
- `[start-end]` — optional 1-indexed line range
|
||||
- Multiple refs per argument supported. Only `workspace://` and absolute paths are expanded.
|
||||
- **Any file** — shows as a clickable download link:
|
||||
`[report.csv](workspace://file_id#text/csv)`
|
||||
- **Image** — renders inline in chat:
|
||||
``
|
||||
- **Video** — renders inline in chat with player controls:
|
||||
``
|
||||
|
||||
The `download_url` field in the `write_workspace_file` response is already
|
||||
in the correct format — paste it directly after the `(` in the Markdown.
|
||||
|
||||
### Passing file content to tools — @@agptfile: references
|
||||
Instead of copying large file contents into a tool argument, pass a file
|
||||
reference and the platform will load the content for you.
|
||||
|
||||
Syntax: `@@agptfile:<uri>[<start>-<end>]`
|
||||
|
||||
- `<uri>` **must** start with `workspace://` or `/` (absolute path):
|
||||
- `workspace://<file_id>` — workspace file by ID
|
||||
- `workspace:///<path>` — workspace file by virtual path
|
||||
- `/absolute/local/path` — ephemeral or sdk_cwd file
|
||||
- E2B sandbox absolute path (e.g. `/home/user/script.py`)
|
||||
- `[<start>-<end>]` is an optional 1-indexed inclusive line range.
|
||||
- URIs that do not start with `workspace://` or `/` are **not** expanded.
|
||||
|
||||
Examples:
|
||||
```
|
||||
@@ -33,16 +49,69 @@ Examples:
|
||||
@@agptfile:/home/user/script.py
|
||||
```
|
||||
|
||||
**Structured data**: When the entire argument is a single file reference, the platform auto-parses by extension/MIME. Supported: JSON, JSONL, CSV, TSV, YAML, TOML, Parquet, Excel (.xlsx only). Unrecognised formats return plain string.
|
||||
You can embed a reference inside any string argument, or use it as the entire
|
||||
value. Multiple references in one argument are all expanded.
|
||||
|
||||
**Type coercion**: The platform auto-coerces expanded string values to match block input types (e.g. JSON string → `list[list[str]]`).
|
||||
**Structured data**: When the **entire** argument value is a single file
|
||||
reference (no surrounding text), the platform automatically parses the file
|
||||
content based on its extension or MIME type. Supported formats: JSON, JSONL,
|
||||
CSV, TSV, YAML, TOML, Parquet, and Excel (.xlsx — first sheet only).
|
||||
For example, pass `@@agptfile:workspace://<id>` where the file is a `.csv` and
|
||||
the rows will be parsed into `list[list[str]]` automatically. If the format is
|
||||
unrecognised or parsing fails, the content is returned as a plain string.
|
||||
Legacy `.xls` files are **not** supported — only the modern `.xlsx` format.
|
||||
|
||||
**Type coercion**: The platform also coerces expanded values to match the
|
||||
block's expected input types. For example, if a block expects `list[list[str]]`
|
||||
and the expanded value is a JSON string, it will be parsed into the correct type.
|
||||
|
||||
### Media file inputs (format: "file")
|
||||
Inputs with `"format": "file"` accept `workspace://<file_id>` or `data:<mime>;base64,<payload>`.
|
||||
Pass the `workspace://` URI directly (do NOT wrap in `@@agptfile:`). This avoids large payloads and preserves binary content.
|
||||
Some block inputs accept media files — their schema shows `"format": "file"`.
|
||||
These fields accept:
|
||||
- **`workspace://<file_id>`** or **`workspace://<file_id>#<mime>`** — preferred
|
||||
for large files (images, videos, PDFs). The platform passes the reference
|
||||
directly to the block without reading the content into memory.
|
||||
- **`data:<mime>;base64,<payload>`** — inline base64 data URI, suitable for
|
||||
small files only.
|
||||
|
||||
When a block input has `format: "file"`, **pass the `workspace://` URI
|
||||
directly as the value** (do NOT wrap it in `@@agptfile:`). This avoids large
|
||||
payloads in tool arguments and preserves binary content (images, videos)
|
||||
that would be corrupted by text encoding.
|
||||
|
||||
Example — committing an image file to GitHub:
|
||||
```json
|
||||
{
|
||||
"files": [{
|
||||
"path": "docs/hero.png",
|
||||
"content": "workspace://abc123#image/png",
|
||||
"operation": "upsert"
|
||||
}]
|
||||
}
|
||||
```
|
||||
|
||||
### Sub-agent tasks
|
||||
- Task tool: NEVER set `run_in_background` to true.
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
"""
|
||||
|
||||
# E2B-only notes — E2B has full internet access so gh CLI works there.
|
||||
# Not shown in local (bubblewrap) mode: --unshare-net blocks all network.
|
||||
_E2B_TOOL_NOTES = """
|
||||
### GitHub CLI (`gh`) and git
|
||||
- If the user has connected their GitHub account, both `gh` and `git` are
|
||||
pre-authenticated — use them directly without any manual login step.
|
||||
`git` HTTPS operations (clone, push, pull) work automatically.
|
||||
- If the token changes mid-session (e.g. user reconnects with a new token),
|
||||
run `gh auth setup-git` to re-register the credential helper.
|
||||
- If `gh` or `git` fails with an authentication error (e.g. "authentication
|
||||
required", "could not read Username", or exit code 128), call
|
||||
`connect_integration(provider="github")` to surface the GitHub credentials
|
||||
setup card so the user can connect their account. Once connected, retry
|
||||
the operation.
|
||||
- For operations that need broader access (e.g. private org repos, GitHub
|
||||
Actions), pass the required scopes: e.g.
|
||||
`connect_integration(provider="github", scopes=["repo", "read:org"])`.
|
||||
"""
|
||||
|
||||
|
||||
@@ -55,6 +124,7 @@ def _build_storage_supplement(
|
||||
storage_system_1_persistence: list[str],
|
||||
file_move_name_1_to_2: str,
|
||||
file_move_name_2_to_1: str,
|
||||
extra_notes: str = "",
|
||||
) -> str:
|
||||
"""Build storage/filesystem supplement for a specific environment.
|
||||
|
||||
@@ -69,6 +139,7 @@ def _build_storage_supplement(
|
||||
storage_system_1_persistence: List of persistence behavior descriptions
|
||||
file_move_name_1_to_2: Direction label for primary→persistent
|
||||
file_move_name_2_to_1: Direction label for persistent→primary
|
||||
extra_notes: Environment-specific notes appended after shared notes
|
||||
"""
|
||||
# Format lists as bullet points with proper indentation
|
||||
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
|
||||
@@ -78,24 +149,40 @@ def _build_storage_supplement(
|
||||
|
||||
## Tool notes
|
||||
|
||||
### Shell & filesystem
|
||||
- Use `bash_exec` for shell commands ({sandbox_type}). Working dir: `{working_dir}`
|
||||
- All file tools share the same filesystem. Use relative or absolute paths under this dir.
|
||||
### Shell commands
|
||||
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
|
||||
for shell commands — it runs {sandbox_type}.
|
||||
|
||||
### Working directory
|
||||
- Your working directory is: `{working_dir}`
|
||||
- All SDK file tools AND `bash_exec` operate on the same filesystem
|
||||
- Use relative paths or absolute paths under `{working_dir}` for all file operations
|
||||
|
||||
### Two storage systems — CRITICAL to understand
|
||||
|
||||
### Storage — important
|
||||
1. **{storage_system_1_name}** (`{working_dir}`):
|
||||
{characteristics}
|
||||
{persistence}
|
||||
2. **Persistent workspace** (cloud) — survives across sessions.
|
||||
- {file_move_name_1_to_2}: use `write_workspace_file`
|
||||
- {file_move_name_2_to_1}: use `read_workspace_file` with save_to_path
|
||||
- Save important files to workspace for persistence.
|
||||
{_SHARED_TOOL_NOTES}"""
|
||||
|
||||
2. **Persistent workspace** (cloud storage):
|
||||
- Files here **survive across sessions indefinitely**
|
||||
|
||||
### Moving files between storages
|
||||
- **{file_move_name_1_to_2}**: Copy to persistent workspace
|
||||
- **{file_move_name_2_to_1}**: Download for processing
|
||||
|
||||
### File persistence
|
||||
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
|
||||
{_SHARED_TOOL_NOTES}{extra_notes}"""
|
||||
|
||||
|
||||
# Pre-built supplements for common environments
|
||||
def _get_local_storage_supplement(cwd: str) -> str:
|
||||
"""Local ephemeral storage (files lost between turns)."""
|
||||
"""Local ephemeral storage (files lost between turns).
|
||||
|
||||
Network is isolated (bubblewrap --unshare-net), so internet-dependent CLIs
|
||||
like gh will not work — no integration env-var notes are included.
|
||||
"""
|
||||
return _build_storage_supplement(
|
||||
working_dir=cwd,
|
||||
sandbox_type="in a network-isolated sandbox",
|
||||
@@ -113,7 +200,11 @@ def _get_local_storage_supplement(cwd: str) -> str:
|
||||
|
||||
|
||||
def _get_cloud_sandbox_supplement() -> str:
|
||||
"""Cloud persistent sandbox (files survive across turns in session)."""
|
||||
"""Cloud persistent sandbox (files survive across turns in session).
|
||||
|
||||
E2B has full internet access, so integration tokens (GH_TOKEN etc.) are
|
||||
injected per command in bash_exec — include the CLI guidance notes.
|
||||
"""
|
||||
return _build_storage_supplement(
|
||||
working_dir="/home/user",
|
||||
sandbox_type="in a cloud sandbox with full internet access",
|
||||
@@ -128,6 +219,7 @@ def _get_cloud_sandbox_supplement() -> str:
|
||||
],
|
||||
file_move_name_1_to_2="Sandbox → Persistent",
|
||||
file_move_name_2_to_1="Persistent → Sandbox",
|
||||
extra_notes=_E2B_TOOL_NOTES,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -769,7 +769,7 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
return None
|
||||
try:
|
||||
return await get_or_create_sandbox(
|
||||
sandbox = await get_or_create_sandbox(
|
||||
session_id,
|
||||
api_key=e2b_api_key,
|
||||
template=config.e2b_sandbox_template,
|
||||
@@ -783,7 +783,9 @@ async def stream_chat_completion_sdk(
|
||||
e2b_err,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
return None
|
||||
|
||||
return sandbox
|
||||
|
||||
async def _fetch_transcript():
|
||||
"""Download transcript for --resume if applicable."""
|
||||
|
||||
@@ -12,6 +12,7 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .bash_exec import BashExecTool
|
||||
from .connect_integration import ConnectIntegrationTool
|
||||
from .continue_run_block import ContinueRunBlockTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .customize_agent import CustomizeAgentTool
|
||||
@@ -84,6 +85,7 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"browser_screenshot": BrowserScreenshotTool(),
|
||||
# Sandboxed code execution (bubblewrap)
|
||||
"bash_exec": BashExecTool(),
|
||||
"connect_integration": ConnectIntegrationTool(),
|
||||
# Persistent workspace tools (cloud storage, survives across sessions)
|
||||
# Feature request tools
|
||||
"search_feature_requests": SearchFeatureRequestsTool(),
|
||||
|
||||
@@ -22,11 +22,13 @@ class AddUnderstandingTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Store user's business context, workflows, pain points, and automation goals. "
|
||||
"Call whenever the user shares business info. Each call incrementally merges "
|
||||
"with existing data — provide only the fields you have."
|
||||
)
|
||||
return """Capture and store information about the user's business context,
|
||||
workflows, pain points, and automation goals. Call this tool whenever the user
|
||||
shares information about their business. Each call incrementally adds to the
|
||||
existing understanding - you don't need to provide all fields at once.
|
||||
|
||||
Use this to build a comprehensive profile that helps recommend better agents
|
||||
and automations for the user's specific needs."""
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
|
||||
@@ -408,11 +408,18 @@ class BrowserNavigateTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Navigate to a URL in a real browser. Returns accessibility tree with @ref IDs "
|
||||
"for browser_act. Session persists (cookies/auth carry over). "
|
||||
"For static pages, prefer web_fetch. "
|
||||
"For SPAs, elements may load late — use browser_act with wait + browser_screenshot to verify. "
|
||||
"For auth: navigate to login, fill creds with browser_act, then navigate to target."
|
||||
"Navigate to a URL using a real browser. Returns an accessibility "
|
||||
"tree snapshot listing the page's interactive elements with @ref IDs "
|
||||
"(e.g. @e3) that can be used with browser_act. "
|
||||
"Session persists — cookies and login state carry over between calls. "
|
||||
"Use this (with browser_act) for multi-step interaction: login flows, "
|
||||
"form filling, button clicks, or anything requiring page interaction. "
|
||||
"For plain static pages, prefer web_fetch — no browser overhead. "
|
||||
"For authenticated pages: navigate to the login page first, use browser_act "
|
||||
"to fill credentials and submit, then navigate to the target page. "
|
||||
"Note: for slow SPAs, the returned snapshot may reflect a partially-loaded "
|
||||
"state. If elements seem missing, use browser_act with action='wait' and a "
|
||||
"CSS selector or millisecond delay, then take a browser_screenshot to verify."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -422,13 +429,13 @@ class BrowserNavigateTool(BaseTool):
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "HTTP/HTTPS URL to navigate to.",
|
||||
"description": "The HTTP/HTTPS URL to navigate to.",
|
||||
},
|
||||
"wait_for": {
|
||||
"type": "string",
|
||||
"enum": ["networkidle", "load", "domcontentloaded"],
|
||||
"default": "networkidle",
|
||||
"description": "Navigation completion strategy (default: networkidle).",
|
||||
"description": "When to consider navigation complete. Use 'networkidle' for SPAs (default).",
|
||||
},
|
||||
},
|
||||
"required": ["url"],
|
||||
@@ -547,12 +554,14 @@ class BrowserActTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Interact with the current browser page using @ref IDs from the snapshot. "
|
||||
"Actions: click, dblclick, fill, type, scroll, hover, press, "
|
||||
"Interact with the current browser page. Use @ref IDs from the "
|
||||
"snapshot (e.g. '@e3') to target elements. Returns an updated snapshot. "
|
||||
"Supported actions: click, dblclick, fill, type, scroll, hover, press, "
|
||||
"check, uncheck, select, wait, back, forward, reload. "
|
||||
"fill clears field first; type appends. "
|
||||
"wait accepts CSS selector or milliseconds (e.g. '1000'). "
|
||||
"Returns updated snapshot."
|
||||
"fill clears the field before typing; type appends without clearing. "
|
||||
"wait accepts a CSS selector (waits for element) or milliseconds string (e.g. '1000'). "
|
||||
"Example login flow: fill @e1 with email → fill @e2 with password → "
|
||||
"click @e3 (submit) → browser_navigate to the target page."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -578,21 +587,30 @@ class BrowserActTool(BaseTool):
|
||||
"forward",
|
||||
"reload",
|
||||
],
|
||||
"description": "Action to perform.",
|
||||
"description": "The action to perform.",
|
||||
},
|
||||
"target": {
|
||||
"type": "string",
|
||||
"description": "@ref ID (e.g. '@e3'), CSS selector, or text description.",
|
||||
"description": (
|
||||
"Element to target. Use @ref from snapshot (e.g. '@e3'), "
|
||||
"a CSS selector, or a text description. "
|
||||
"Required for: click, dblclick, fill, type, hover, check, uncheck, select. "
|
||||
"For wait: a CSS selector to wait for, or milliseconds as a string (e.g. '1000')."
|
||||
),
|
||||
},
|
||||
"value": {
|
||||
"type": "string",
|
||||
"description": "Text for fill/type, key for press (e.g. 'Enter'), option for select.",
|
||||
"description": (
|
||||
"For fill/type: the text to enter. "
|
||||
"For press: key name (e.g. 'Enter', 'Tab', 'Control+a'). "
|
||||
"For select: the option value to select."
|
||||
),
|
||||
},
|
||||
"direction": {
|
||||
"type": "string",
|
||||
"enum": ["up", "down", "left", "right"],
|
||||
"default": "down",
|
||||
"description": "Scroll direction (default: down).",
|
||||
"description": "For scroll: direction to scroll.",
|
||||
},
|
||||
},
|
||||
"required": ["action"],
|
||||
@@ -739,10 +757,12 @@ class BrowserScreenshotTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Screenshot the current browser page and save to workspace. "
|
||||
"annotate=true overlays @ref labels on elements. "
|
||||
"IMPORTANT: After calling, you MUST immediately call read_workspace_file with the "
|
||||
"returned file_id to display the image inline."
|
||||
"Take a screenshot of the current browser page and save it to the workspace. "
|
||||
"IMPORTANT: After calling this tool, immediately call read_workspace_file "
|
||||
"with the returned file_id to display the image inline to the user — "
|
||||
"the screenshot is not visible until you do this. "
|
||||
"With annotate=true (default), @ref labels are overlaid on interactive "
|
||||
"elements, making it easy to see which @ref ID maps to which element on screen."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -753,12 +773,12 @@ class BrowserScreenshotTool(BaseTool):
|
||||
"annotate": {
|
||||
"type": "boolean",
|
||||
"default": True,
|
||||
"description": "Overlay @ref labels (default: true).",
|
||||
"description": "Overlay @ref labels on interactive elements (default: true).",
|
||||
},
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"default": "screenshot.png",
|
||||
"description": "Workspace filename (default: screenshot.png).",
|
||||
"description": "Filename to save in the workspace.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -108,12 +108,22 @@ class AgentOutputTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Retrieve execution outputs from a library agent. "
|
||||
"Identify by agent_name, library_agent_id, or store_slug. "
|
||||
"Filter by execution_id or run_time. "
|
||||
"Optionally wait for running executions."
|
||||
)
|
||||
return """Retrieve execution outputs from agents in the user's library.
|
||||
|
||||
Identify the agent using one of:
|
||||
- agent_name: Fuzzy search in user's library
|
||||
- library_agent_id: Exact library agent ID
|
||||
- store_slug: Marketplace format 'username/agent-name'
|
||||
|
||||
Select which run to retrieve using:
|
||||
- execution_id: Specific execution ID
|
||||
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
|
||||
|
||||
Wait for completion (optional):
|
||||
- wait_if_running: Max seconds to wait if execution is still running (0-300).
|
||||
If the execution is running/queued, waits up to this many seconds for completion.
|
||||
Returns current status on timeout. If already finished, returns immediately.
|
||||
"""
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
@@ -122,27 +132,32 @@ class AgentOutputTool(BaseTool):
|
||||
"properties": {
|
||||
"agent_name": {
|
||||
"type": "string",
|
||||
"description": "Agent name (fuzzy match).",
|
||||
"description": "Agent name to search for in user's library (fuzzy match)",
|
||||
},
|
||||
"library_agent_id": {
|
||||
"type": "string",
|
||||
"description": "Library agent ID.",
|
||||
"description": "Exact library agent ID",
|
||||
},
|
||||
"store_slug": {
|
||||
"type": "string",
|
||||
"description": "Marketplace 'username/agent-slug'.",
|
||||
"description": "Marketplace identifier: 'username/agent-slug'",
|
||||
},
|
||||
"execution_id": {
|
||||
"type": "string",
|
||||
"description": "Specific execution ID.",
|
||||
"description": "Specific execution ID to retrieve",
|
||||
},
|
||||
"run_time": {
|
||||
"type": "string",
|
||||
"description": "Time filter: 'latest', today/yesterday/last week/last 7 days/last month/last 30 days, 'YYYY-MM-DD', or ISO datetime.",
|
||||
"description": (
|
||||
"Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
|
||||
),
|
||||
},
|
||||
"wait_if_running": {
|
||||
"type": "integer",
|
||||
"description": "Max seconds to wait if still running (0-300). Returns current state on timeout.",
|
||||
"description": (
|
||||
"Max seconds to wait if execution is still running (0-300). "
|
||||
"If running, waits for completion. Returns current state on timeout."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
|
||||
@@ -22,6 +22,7 @@ from e2b import AsyncSandbox
|
||||
from e2b.exceptions import TimeoutException
|
||||
|
||||
from backend.copilot.context import E2B_WORKDIR, get_current_sandbox
|
||||
from backend.copilot.integration_creds import get_integration_env_vars
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
@@ -41,9 +42,15 @@ class BashExecTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Execute a Bash command or script. Shares filesystem with SDK file tools. "
|
||||
"Useful for scripts, data processing, and package installation. "
|
||||
"Killed after timeout (default 30s, max 120s)."
|
||||
"Execute a Bash command or script. "
|
||||
"Full Bash scripting is supported (loops, conditionals, pipes, "
|
||||
"functions, etc.). "
|
||||
"The working directory is shared with the SDK Read/Write/Edit/Glob/Grep "
|
||||
"tools — files created by either are immediately visible to both. "
|
||||
"Execution is killed after the timeout (default 30s, max 120s). "
|
||||
"Returns stdout and stderr. "
|
||||
"Useful for file manipulation, data processing, running scripts, "
|
||||
"and installing packages."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -53,11 +60,13 @@ class BashExecTool(BaseTool):
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "Bash command or script.",
|
||||
"description": "Bash command or script to execute.",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Max seconds (default 30, max 120).",
|
||||
"description": (
|
||||
"Max execution time in seconds (default 30, max 120)."
|
||||
),
|
||||
"default": 30,
|
||||
},
|
||||
},
|
||||
@@ -88,7 +97,9 @@ class BashExecTool(BaseTool):
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
return await self._execute_on_e2b(sandbox, command, timeout, session_id)
|
||||
return await self._execute_on_e2b(
|
||||
sandbox, command, timeout, session_id, user_id
|
||||
)
|
||||
|
||||
# Bubblewrap fallback: local isolated execution.
|
||||
if not has_full_sandbox():
|
||||
@@ -125,14 +136,27 @@ class BashExecTool(BaseTool):
|
||||
command: str,
|
||||
timeout: int,
|
||||
session_id: str | None,
|
||||
user_id: str | None = None,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute *command* on the E2B sandbox via commands.run()."""
|
||||
"""Execute *command* on the E2B sandbox via commands.run().
|
||||
|
||||
Integration tokens (e.g. GH_TOKEN) are injected into the sandbox env
|
||||
for any user with connected accounts. E2B has full internet access, so
|
||||
CLI tools like ``gh`` work without manual authentication.
|
||||
"""
|
||||
envs: dict[str, str] = {
|
||||
"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin",
|
||||
}
|
||||
if user_id is not None:
|
||||
integration_env = await get_integration_env_vars(user_id)
|
||||
envs.update(integration_env)
|
||||
|
||||
try:
|
||||
result = await sandbox.commands.run(
|
||||
f"bash -c {shlex.quote(command)}",
|
||||
cwd=E2B_WORKDIR,
|
||||
timeout=timeout,
|
||||
envs={"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"},
|
||||
envs=envs,
|
||||
)
|
||||
return BashExecResponse(
|
||||
message=f"Command executed on E2B (exit {result.exit_code})",
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Tests for BashExecTool — E2B path with token injection."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ._test_data import make_session
|
||||
from .bash_exec import BashExecTool
|
||||
from .models import BashExecResponse
|
||||
|
||||
_USER = "user-bash-exec-test"
|
||||
|
||||
|
||||
def _make_tool() -> BashExecTool:
|
||||
return BashExecTool()
|
||||
|
||||
|
||||
def _make_sandbox(exit_code: int = 0, stdout: str = "", stderr: str = "") -> MagicMock:
|
||||
result = MagicMock()
|
||||
result.exit_code = exit_code
|
||||
result.stdout = stdout
|
||||
result.stderr = stderr
|
||||
|
||||
sandbox = MagicMock()
|
||||
sandbox.commands.run = AsyncMock(return_value=result)
|
||||
return sandbox
|
||||
|
||||
|
||||
class TestBashExecE2BTokenInjection:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_token_injected_when_user_id_set(self):
|
||||
"""When user_id is provided, integration env vars are merged into sandbox envs."""
|
||||
tool = _make_tool()
|
||||
session = make_session(user_id=_USER)
|
||||
sandbox = _make_sandbox(stdout="ok")
|
||||
env_vars = {"GH_TOKEN": "gh-secret", "GITHUB_TOKEN": "gh-secret"}
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value=env_vars),
|
||||
) as mock_get_env:
|
||||
result = await tool._execute_on_e2b(
|
||||
sandbox=sandbox,
|
||||
command="echo hi",
|
||||
timeout=10,
|
||||
session_id=session.session_id,
|
||||
user_id=_USER,
|
||||
)
|
||||
|
||||
mock_get_env.assert_awaited_once_with(_USER)
|
||||
call_kwargs = sandbox.commands.run.call_args[1]
|
||||
assert call_kwargs["envs"]["GH_TOKEN"] == "gh-secret"
|
||||
assert call_kwargs["envs"]["GITHUB_TOKEN"] == "gh-secret"
|
||||
assert isinstance(result, BashExecResponse)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_token_injection_when_user_id_is_none(self):
|
||||
"""When user_id is None, get_integration_env_vars must NOT be called."""
|
||||
tool = _make_tool()
|
||||
session = make_session(user_id=_USER)
|
||||
sandbox = _make_sandbox(stdout="ok")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value={"GH_TOKEN": "should-not-appear"}),
|
||||
) as mock_get_env:
|
||||
result = await tool._execute_on_e2b(
|
||||
sandbox=sandbox,
|
||||
command="echo hi",
|
||||
timeout=10,
|
||||
session_id=session.session_id,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
mock_get_env.assert_not_called()
|
||||
call_kwargs = sandbox.commands.run.call_args[1]
|
||||
assert "GH_TOKEN" not in call_kwargs["envs"]
|
||||
assert isinstance(result, BashExecResponse)
|
||||
@@ -0,0 +1,215 @@
|
||||
"""Tool for prompting the user to connect a required integration.
|
||||
|
||||
When the copilot encounters an authentication failure (e.g. `gh` CLI returns
|
||||
"authentication required"), it calls this tool to surface the credentials
|
||||
setup card in the chat — the same UI that appears when a GitHub block runs
|
||||
without configured credentials.
|
||||
"""
|
||||
|
||||
import functools
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.models import (
|
||||
ErrorResponse,
|
||||
ResponseType,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
|
||||
|
||||
class _ProviderInfo(TypedDict):
|
||||
name: str
|
||||
types: list[str]
|
||||
# Default OAuth scopes requested when the agent doesn't specify any.
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
class _CredentialEntry(TypedDict):
|
||||
"""Shape of each entry inside SetupRequirementsResponse.user_readiness.missing_credentials."""
|
||||
|
||||
id: str
|
||||
title: str
|
||||
provider: str
|
||||
provider_name: str
|
||||
type: str
|
||||
types: list[str]
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _is_github_oauth_configured() -> bool:
|
||||
"""Return True if GitHub OAuth env vars are set.
|
||||
|
||||
Evaluated lazily (not at import time) to avoid triggering Secrets() during
|
||||
module import, which can fail in environments where secrets are not loaded.
|
||||
"""
|
||||
from backend.blocks.github._auth import GITHUB_OAUTH_IS_CONFIGURED
|
||||
|
||||
return GITHUB_OAUTH_IS_CONFIGURED
|
||||
|
||||
|
||||
# Registry of known providers: name + supported credential types for the UI.
|
||||
# When adding a new provider, also add its env var names to
|
||||
# backend.copilot.integration_creds.PROVIDER_ENV_VARS.
|
||||
def _get_provider_info() -> dict[str, _ProviderInfo]:
|
||||
"""Build the provider registry, evaluating OAuth config lazily."""
|
||||
return {
|
||||
"github": {
|
||||
"name": "GitHub",
|
||||
"types": (
|
||||
["api_key", "oauth2"] if _is_github_oauth_configured() else ["api_key"]
|
||||
),
|
||||
# Default: repo scope covers clone/push/pull for public and private repos.
|
||||
# Agent can request additional scopes (e.g. "read:org") via the scopes param.
|
||||
"scopes": ["repo"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ConnectIntegrationTool(BaseTool):
|
||||
"""Surface the credentials setup UI when an integration is not connected."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "connect_integration"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Prompt the user to connect a required integration (e.g. GitHub). "
|
||||
"Call this when an external CLI or API call fails because the user "
|
||||
"has not connected the relevant account. "
|
||||
"The tool surfaces a credentials setup card in the chat so the user "
|
||||
"can authenticate without leaving the page. "
|
||||
"After the user connects the account, retry the operation. "
|
||||
"In E2B/cloud sandbox mode the token (GH_TOKEN/GITHUB_TOKEN) is "
|
||||
"automatically injected per-command in bash_exec — no manual export needed. "
|
||||
"In local bubblewrap mode network is isolated so GitHub CLI commands "
|
||||
"will still fail after connecting; inform the user of this limitation."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"provider": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Integration provider slug, e.g. 'github'. "
|
||||
"Must be one of the supported providers."
|
||||
),
|
||||
"enum": list(_get_provider_info().keys()),
|
||||
},
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Brief explanation of why the integration is needed, "
|
||||
"shown to the user in the setup card."
|
||||
),
|
||||
"maxLength": 500,
|
||||
},
|
||||
"scopes": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"OAuth scopes to request. Omit to use the provider default. "
|
||||
"Add extra scopes when you need more access — e.g. for GitHub: "
|
||||
"'repo' (clone/push/pull), 'read:org' (org membership), "
|
||||
"'workflow' (GitHub Actions). "
|
||||
"Requesting only the scopes you actually need is best practice."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["provider"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
# Require auth so only authenticated users can trigger the setup card.
|
||||
# The card itself is user-agnostic (no per-user data needed), so
|
||||
# user_id is intentionally unused in _execute.
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
del user_id # setup card is user-agnostic; auth is enforced via requires_auth
|
||||
session_id = session.session_id if session else None
|
||||
provider: str = (kwargs.get("provider") or "").strip().lower()
|
||||
reason: str = (kwargs.get("reason") or "").strip()[
|
||||
:500
|
||||
] # cap LLM-controlled text
|
||||
extra_scopes: list[str] = [
|
||||
str(s).strip() for s in (kwargs.get("scopes") or []) if str(s).strip()
|
||||
]
|
||||
|
||||
provider_info = _get_provider_info()
|
||||
info = provider_info.get(provider)
|
||||
if not info:
|
||||
supported = ", ".join(f"'{p}'" for p in provider_info)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Unknown provider '{provider}'. "
|
||||
f"Supported providers: {supported}."
|
||||
),
|
||||
error="unknown_provider",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
provider_name: str = info["name"]
|
||||
supported_types: list[str] = info["types"]
|
||||
# Merge agent-requested scopes with provider defaults (deduplicated, order preserved).
|
||||
default_scopes: list[str] = info["scopes"]
|
||||
seen: set[str] = set()
|
||||
scopes: list[str] = []
|
||||
for s in default_scopes + extra_scopes:
|
||||
if s not in seen:
|
||||
seen.add(s)
|
||||
scopes.append(s)
|
||||
field_key = f"{provider}_credentials"
|
||||
|
||||
message_parts = [
|
||||
f"To continue, please connect your {provider_name} account.",
|
||||
]
|
||||
if reason:
|
||||
message_parts.append(reason)
|
||||
|
||||
credential_entry: _CredentialEntry = {
|
||||
"id": field_key,
|
||||
"title": f"{provider_name} Credentials",
|
||||
"provider": provider,
|
||||
"provider_name": provider_name,
|
||||
"type": supported_types[0],
|
||||
"types": supported_types,
|
||||
"scopes": scopes,
|
||||
}
|
||||
missing_credentials: dict[str, _CredentialEntry] = {field_key: credential_entry}
|
||||
|
||||
return SetupRequirementsResponse(
|
||||
type=ResponseType.SETUP_REQUIREMENTS,
|
||||
message=" ".join(message_parts),
|
||||
session_id=session_id,
|
||||
setup_info=SetupInfo(
|
||||
agent_id=f"connect_{provider}",
|
||||
agent_name=provider_name,
|
||||
user_readiness=UserReadiness(
|
||||
has_all_credentials=False,
|
||||
missing_credentials=missing_credentials,
|
||||
ready_to_run=False,
|
||||
),
|
||||
requirements={
|
||||
"credentials": [missing_credentials[field_key]],
|
||||
"inputs": [],
|
||||
"execution_modes": [],
|
||||
},
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,135 @@
|
||||
"""Tests for ConnectIntegrationTool."""
|
||||
|
||||
import pytest
|
||||
|
||||
from ._test_data import make_session
|
||||
from .connect_integration import ConnectIntegrationTool
|
||||
from .models import ErrorResponse, SetupRequirementsResponse
|
||||
|
||||
_TEST_USER_ID = "test-user-connect-integration"
|
||||
|
||||
|
||||
class TestConnectIntegrationTool:
|
||||
def _make_tool(self) -> ConnectIntegrationTool:
|
||||
return ConnectIntegrationTool()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unknown_provider_returns_error(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="nonexistent"
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "unknown_provider"
|
||||
assert "nonexistent" in result.message
|
||||
assert "github" in result.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_empty_provider_returns_error(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider=""
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "unknown_provider"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_github_provider_returns_setup_response(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert result.setup_info.agent_name == "GitHub"
|
||||
assert result.setup_info.agent_id == "connect_github"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_github_has_missing_credentials_in_readiness(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
readiness = result.setup_info.user_readiness
|
||||
assert readiness.has_all_credentials is False
|
||||
assert readiness.ready_to_run is False
|
||||
assert "github_credentials" in readiness.missing_credentials
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_github_requirements_include_credential_entry(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
creds = result.setup_info.requirements["credentials"]
|
||||
assert len(creds) == 1
|
||||
assert creds[0]["provider"] == "github"
|
||||
assert creds[0]["id"] == "github_credentials"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_reason_appears_in_message(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
reason = "Needed to create a pull request."
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github", reason=reason
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert reason in result.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_session_id_propagated(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert result.session_id == session.session_id
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_provider_case_insensitive(self):
|
||||
"""Provider slug is normalised to lowercase before lookup."""
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="GitHub"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
|
||||
def test_tool_name(self):
|
||||
assert ConnectIntegrationTool().name == "connect_integration"
|
||||
|
||||
def test_requires_auth(self):
|
||||
assert ConnectIntegrationTool().requires_auth is True
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unauthenticated_user_gets_need_login_response(self):
|
||||
"""execute() with user_id=None must return NeedLoginResponse, not the setup card.
|
||||
|
||||
This verifies that the requires_auth guard in BaseTool.execute() fires
|
||||
before _execute() is called, so unauthenticated callers cannot probe
|
||||
which integrations are configured.
|
||||
"""
|
||||
import json
|
||||
|
||||
tool = self._make_tool()
|
||||
# Session still needs a user_id string; the None is passed to execute()
|
||||
# to simulate an unauthenticated call.
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool.execute(
|
||||
user_id=None,
|
||||
session=session,
|
||||
tool_call_id="test-call-id",
|
||||
provider="github",
|
||||
)
|
||||
raw = result.output
|
||||
output = json.loads(raw) if isinstance(raw, str) else raw
|
||||
assert output.get("type") == "need_login"
|
||||
assert result.success is False
|
||||
@@ -30,7 +30,12 @@ class ContinueRunBlockTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Resume block execution after human review approval. Pass the review_id."
|
||||
return (
|
||||
"Continue executing a block after human review approval. "
|
||||
"Use this after a run_block call returned review_required. "
|
||||
"Pass the review_id from the review_required response. "
|
||||
"The block will execute with the original pre-approved input data."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
@@ -39,7 +44,10 @@ class ContinueRunBlockTool(BaseTool):
|
||||
"properties": {
|
||||
"review_id": {
|
||||
"type": "string",
|
||||
"description": "review_id from the review_required response.",
|
||||
"description": (
|
||||
"The review_id from a previous review_required response. "
|
||||
"This resumes execution with the pre-approved input data."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["review_id"],
|
||||
|
||||
@@ -23,8 +23,12 @@ class CreateAgentTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Create a new agent from JSON (nodes + links). Validates, auto-fixes, and saves. "
|
||||
"Before calling, search for existing agents with find_library_agent."
|
||||
"Create a new agent workflow. Pass `agent_json` with the complete "
|
||||
"agent graph JSON you generated using block schemas from find_block. "
|
||||
"The tool validates, auto-fixes, and saves.\n\n"
|
||||
"IMPORTANT: Before calling this tool, search for relevant existing agents "
|
||||
"using find_library_agent that could be used as building blocks. "
|
||||
"Pass their IDs in the library_agent_ids parameter."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -38,21 +42,34 @@ class CreateAgentTool(BaseTool):
|
||||
"properties": {
|
||||
"agent_json": {
|
||||
"type": "object",
|
||||
"description": "Agent graph with 'nodes' and 'links' arrays.",
|
||||
"description": (
|
||||
"The agent JSON to validate and save. "
|
||||
"Must contain 'nodes' and 'links' arrays, and optionally "
|
||||
"'name' and 'description'."
|
||||
),
|
||||
},
|
||||
"library_agent_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Library agent IDs as building blocks.",
|
||||
"description": (
|
||||
"List of library agent IDs to use as building blocks."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": "Save the agent (default: true). False for preview.",
|
||||
"description": (
|
||||
"Whether to save the agent. Default is true. "
|
||||
"Set to false for preview only."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
"folder_id": {
|
||||
"type": "string",
|
||||
"description": "Folder ID to save into (default: root).",
|
||||
"description": (
|
||||
"Optional folder ID to save the agent into. "
|
||||
"If not provided, the agent is saved at root level. "
|
||||
"Use list_folders to find available folders."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["agent_json"],
|
||||
|
||||
@@ -23,7 +23,9 @@ class CustomizeAgentTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Customize a marketplace/template agent. Validates, auto-fixes, and saves."
|
||||
"Customize a marketplace or template agent. Pass `agent_json` "
|
||||
"with the complete customized agent JSON. The tool validates, "
|
||||
"auto-fixes, and saves."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -37,21 +39,32 @@ class CustomizeAgentTool(BaseTool):
|
||||
"properties": {
|
||||
"agent_json": {
|
||||
"type": "object",
|
||||
"description": "Customized agent JSON with nodes and links.",
|
||||
"description": (
|
||||
"Complete customized agent JSON to validate and save. "
|
||||
"Optionally include 'name' and 'description'."
|
||||
),
|
||||
},
|
||||
"library_agent_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Library agent IDs as building blocks.",
|
||||
"description": (
|
||||
"List of library agent IDs to use as building blocks."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": "Save the agent (default: true). False for preview.",
|
||||
"description": (
|
||||
"Whether to save the customized agent. Default is true."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
"folder_id": {
|
||||
"type": "string",
|
||||
"description": "Folder ID to save into (default: root).",
|
||||
"description": (
|
||||
"Optional folder ID to save the agent into. "
|
||||
"If not provided, the agent is saved at root level. "
|
||||
"Use list_folders to find available folders."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["agent_json"],
|
||||
|
||||
@@ -23,8 +23,12 @@ class EditAgentTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit an existing agent. Validates, auto-fixes, and saves. "
|
||||
"Before calling, search for existing agents with find_library_agent."
|
||||
"Edit an existing agent. Pass `agent_json` with the complete "
|
||||
"updated agent JSON you generated. The tool validates, auto-fixes, "
|
||||
"and saves.\n\n"
|
||||
"IMPORTANT: Before calling this tool, if the changes involve adding new "
|
||||
"functionality, search for relevant existing agents using find_library_agent "
|
||||
"that could be used as building blocks."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -38,20 +42,33 @@ class EditAgentTool(BaseTool):
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": "Graph ID or library agent ID to edit.",
|
||||
"description": (
|
||||
"The ID of the agent to edit. "
|
||||
"Can be a graph ID or library agent ID."
|
||||
),
|
||||
},
|
||||
"agent_json": {
|
||||
"type": "object",
|
||||
"description": "Updated agent JSON with nodes and links.",
|
||||
"description": (
|
||||
"Complete updated agent JSON to validate and save. "
|
||||
"Must contain 'nodes' and 'links'. "
|
||||
"Include 'name' and/or 'description' if they need "
|
||||
"to be updated."
|
||||
),
|
||||
},
|
||||
"library_agent_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Library agent IDs as building blocks.",
|
||||
"description": (
|
||||
"List of library agent IDs to use as building blocks for the changes."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": "Save changes (default: true). False for preview.",
|
||||
"description": (
|
||||
"Whether to save the changes. "
|
||||
"Default is true. Set to false for preview only."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -134,7 +134,11 @@ class SearchFeatureRequestsTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Search existing feature requests. Check before creating a new one."
|
||||
return (
|
||||
"Search existing feature requests to check if a similar request "
|
||||
"already exists before creating a new one. Returns matching feature "
|
||||
"requests with their ID, title, and description."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
@@ -230,9 +234,14 @@ class CreateFeatureRequestTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Create a feature request or add need to existing one. "
|
||||
"Search first to avoid duplicates. Pass existing_issue_id to add to existing. "
|
||||
"Never include PII (names, emails, phone numbers, company names) in title/description."
|
||||
"Create a new feature request or add a customer need to an existing one. "
|
||||
"Always search first with search_feature_requests to avoid duplicates. "
|
||||
"If a matching request exists, pass its ID as existing_issue_id to add "
|
||||
"the user's need to it instead of creating a duplicate. "
|
||||
"IMPORTANT: Never include personally identifiable information (PII) in "
|
||||
"the title or description — no names, emails, phone numbers, company "
|
||||
"names, or other identifying details. Write titles and descriptions in "
|
||||
"generic, feature-focused language."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -242,15 +251,28 @@ class CreateFeatureRequestTool(BaseTool):
|
||||
"properties": {
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "Feature request title. No PII.",
|
||||
"description": (
|
||||
"Title for the feature request. Must be generic and "
|
||||
"feature-focused — do not include any user names, emails, "
|
||||
"company names, or other PII."
|
||||
),
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "What the user wants and why. No PII.",
|
||||
"description": (
|
||||
"Detailed description of what the user wants and why. "
|
||||
"Must not contain any personally identifiable information "
|
||||
"(PII) — describe the feature need generically without "
|
||||
"referencing specific users, companies, or contact details."
|
||||
),
|
||||
},
|
||||
"existing_issue_id": {
|
||||
"type": "string",
|
||||
"description": "Linear issue ID to add need to (from search results).",
|
||||
"description": (
|
||||
"If adding a need to an existing feature request, "
|
||||
"provide its Linear issue ID (from search results). "
|
||||
"Omit to create a new feature request."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["title", "description"],
|
||||
|
||||
@@ -18,7 +18,9 @@ class FindAgentTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Search marketplace agents by capability."
|
||||
return (
|
||||
"Discover agents from the marketplace based on capabilities and user needs."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
@@ -27,7 +29,7 @@ class FindAgentTool(BaseTool):
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search keywords (single keywords work best).",
|
||||
"description": "Search query describing what the user wants to accomplish. Use single keywords for best results.",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
|
||||
@@ -51,7 +51,14 @@ class FindBlockTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Search blocks by name or description. Returns block IDs for run_block. Always call this FIRST to get block IDs before using run_block."
|
||||
return (
|
||||
"Search for available blocks by name or description. "
|
||||
"Blocks are reusable components that perform specific tasks like "
|
||||
"sending emails, making API calls, processing text, etc. "
|
||||
"IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. "
|
||||
"The response includes each block's id, name, and description. "
|
||||
"Call run_block with the block's id **with no inputs** to see detailed inputs/outputs and execute it."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
@@ -60,11 +67,18 @@ class FindBlockTool(BaseTool):
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search keywords (e.g. 'email', 'http', 'ai').",
|
||||
"description": (
|
||||
"Search query to find blocks by name or description. "
|
||||
"Use keywords like 'email', 'http', 'text', 'ai', etc."
|
||||
),
|
||||
},
|
||||
"include_schemas": {
|
||||
"type": "boolean",
|
||||
"description": "Include full input/output schemas (for agent JSON generation).",
|
||||
"description": (
|
||||
"If true, include full input_schema and output_schema "
|
||||
"for each block. Use when generating agent JSON that "
|
||||
"needs block schemas. Default is false."
|
||||
),
|
||||
"default": False,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -19,8 +19,13 @@ class FindLibraryAgentTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search user's library agents. Returns graph_id, schemas for sub-agent composition. "
|
||||
"Omit query to list all."
|
||||
"Search for or list agents in the user's library. Use this to find "
|
||||
"agents the user has already added to their library, including agents "
|
||||
"they created or added from the marketplace. "
|
||||
"When creating agents with sub-agent composition, use this to get "
|
||||
"the agent's graph_id, graph_version, input_schema, and output_schema "
|
||||
"needed for AgentExecutorBlock nodes. "
|
||||
"Omit the query to list all agents."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -30,7 +35,10 @@ class FindLibraryAgentTool(BaseTool):
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search by name/description. Omit to list all.",
|
||||
"description": (
|
||||
"Search query to find agents by name or description. "
|
||||
"Omit to list all agents in the library."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
|
||||
@@ -22,8 +22,20 @@ class FixAgentGraphTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Auto-fix common agent JSON issues (UUIDs, types, credentials, spacing, etc.). "
|
||||
"Returns fixed JSON and list of fixes applied."
|
||||
"Auto-fix common issues in an agent JSON graph. Applies fixes for:\n"
|
||||
"- Missing or invalid UUIDs on nodes and links\n"
|
||||
"- StoreValueBlock prerequisites for ConditionBlock\n"
|
||||
"- Double curly brace escaping in prompt templates\n"
|
||||
"- AddToList/AddToDictionary prerequisite blocks\n"
|
||||
"- CodeExecutionBlock output field naming\n"
|
||||
"- Missing credentials configuration\n"
|
||||
"- Node X coordinate spacing (800+ units apart)\n"
|
||||
"- AI model default parameters\n"
|
||||
"- Link static properties based on input schema\n"
|
||||
"- Type mismatches (inserts conversion blocks)\n\n"
|
||||
"Returns the fixed agent JSON plus a list of fixes applied. "
|
||||
"After fixing, the agent is re-validated. If still invalid, "
|
||||
"the remaining errors are included in the response."
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -42,7 +42,12 @@ class GetAgentBuildingGuideTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Get the agent JSON building guide (nodes, links, AgentExecutorBlock, MCPToolBlock usage). Call before generating agent JSON."
|
||||
return (
|
||||
"Returns the complete guide for building agent JSON graphs, including "
|
||||
"block IDs, link structure, AgentInputBlock, AgentOutputBlock, "
|
||||
"AgentExecutorBlock (for sub-agent composition), and MCPToolBlock usage. "
|
||||
"Call this before generating agent JSON to ensure correct structure."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
|
||||
@@ -25,7 +25,8 @@ class GetDocPageTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Read full documentation page content by path (from search_docs results)."
|
||||
"Get the full content of a documentation page by its path. "
|
||||
"Use this after search_docs to read the complete content of a relevant page."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -35,7 +36,10 @@ class GetDocPageTool(BaseTool):
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Doc file path (e.g. 'platform/block-sdk-guide.md').",
|
||||
"description": (
|
||||
"The path to the documentation file, as returned by search_docs. "
|
||||
"Example: 'platform/block-sdk-guide.md'"
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["path"],
|
||||
|
||||
@@ -38,7 +38,11 @@ class GetMCPGuideTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Get MCP server URLs and auth guide."
|
||||
return (
|
||||
"Returns the MCP tool guide: known hosted server URLs (Notion, Linear, "
|
||||
"Stripe, Intercom, Cloudflare, Atlassian) and authentication workflow. "
|
||||
"Call before using run_mcp_tool if you need a server URL or auth info."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
|
||||
@@ -88,7 +88,10 @@ class CreateFolderTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Create a library folder. Use parent_id to nest inside another folder."
|
||||
return (
|
||||
"Create a new folder in the user's library to organize agents. "
|
||||
"Optionally nest it inside an existing folder using parent_id."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
@@ -101,19 +104,22 @@ class CreateFolderTool(BaseTool):
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Folder name (max 100 chars).",
|
||||
"description": "Name for the new folder (max 100 chars).",
|
||||
},
|
||||
"parent_id": {
|
||||
"type": "string",
|
||||
"description": "Parent folder ID (omit for root).",
|
||||
"description": (
|
||||
"ID of the parent folder to nest inside. "
|
||||
"Omit to create at root level."
|
||||
),
|
||||
},
|
||||
"icon": {
|
||||
"type": "string",
|
||||
"description": "Icon identifier.",
|
||||
"description": "Optional icon identifier for the folder.",
|
||||
},
|
||||
"color": {
|
||||
"type": "string",
|
||||
"description": "Hex color (#RRGGBB).",
|
||||
"description": "Optional hex color code (#RRGGBB).",
|
||||
},
|
||||
},
|
||||
"required": ["name"],
|
||||
@@ -169,8 +175,13 @@ class ListFoldersTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"List library folders. Omit parent_id for full tree. "
|
||||
"Set include_agents=true when user asks about agents in folders."
|
||||
"List the user's library folders. "
|
||||
"Omit parent_id to get the full folder tree. "
|
||||
"Provide parent_id to list only direct children of that folder. "
|
||||
"Set include_agents=true to also return the agents inside each folder "
|
||||
"and root-level agents not in any folder. Always set include_agents=true "
|
||||
"when the user asks about agents, wants to see what's in their folders, "
|
||||
"or mentions agents alongside folders."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -184,11 +195,17 @@ class ListFoldersTool(BaseTool):
|
||||
"properties": {
|
||||
"parent_id": {
|
||||
"type": "string",
|
||||
"description": "List children of this folder (omit for full tree).",
|
||||
"description": (
|
||||
"List children of this folder. "
|
||||
"Omit to get the full folder tree."
|
||||
),
|
||||
},
|
||||
"include_agents": {
|
||||
"type": "boolean",
|
||||
"description": "Include agents in each folder (default: false).",
|
||||
"description": (
|
||||
"Whether to include the list of agents inside each folder. "
|
||||
"Defaults to false."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
@@ -340,7 +357,10 @@ class MoveFolderTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Move a folder. Set target_parent_id to null for root."
|
||||
return (
|
||||
"Move a folder to a different parent folder. "
|
||||
"Set target_parent_id to null to move to root level."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
@@ -353,11 +373,14 @@ class MoveFolderTool(BaseTool):
|
||||
"properties": {
|
||||
"folder_id": {
|
||||
"type": "string",
|
||||
"description": "Folder ID.",
|
||||
"description": "ID of the folder to move.",
|
||||
},
|
||||
"target_parent_id": {
|
||||
"type": ["string", "null"],
|
||||
"description": "New parent folder ID (null for root).",
|
||||
"description": (
|
||||
"ID of the new parent folder. "
|
||||
"Use null to move to root level."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["folder_id"],
|
||||
@@ -410,7 +433,10 @@ class DeleteFolderTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Delete a folder. Agents inside move to root (not deleted)."
|
||||
return (
|
||||
"Delete a folder from the user's library. "
|
||||
"Agents inside the folder are moved to root level (not deleted)."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
@@ -473,7 +499,10 @@ class MoveAgentsToFolderTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Move agents to a folder. Set folder_id to null for root."
|
||||
return (
|
||||
"Move one or more agents to a folder. "
|
||||
"Set folder_id to null to move agents to root level."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
@@ -487,11 +516,13 @@ class MoveAgentsToFolderTool(BaseTool):
|
||||
"agent_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Library agent IDs to move.",
|
||||
"description": "List of library agent IDs to move.",
|
||||
},
|
||||
"folder_id": {
|
||||
"type": ["string", "null"],
|
||||
"description": "Target folder ID (null for root).",
|
||||
"description": (
|
||||
"Target folder ID. Use null to move to root level."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["agent_ids"],
|
||||
|
||||
@@ -104,11 +104,19 @@ class RunAgentTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Run or schedule an agent. Automatically checks inputs and credentials. "
|
||||
"Identify by username_agent_slug ('user/agent') or library_agent_id. "
|
||||
"For scheduling, provide schedule_name + cron."
|
||||
)
|
||||
return """Run or schedule an agent from the marketplace or user's library.
|
||||
|
||||
The tool automatically handles the setup flow:
|
||||
- Returns missing inputs if required fields are not provided
|
||||
- Returns missing credentials if user needs to configure them
|
||||
- Executes immediately if all requirements are met
|
||||
- Schedules execution if cron expression is provided
|
||||
|
||||
Identify the agent using either:
|
||||
- username_agent_slug: Marketplace format 'username/agent-name'
|
||||
- library_agent_id: ID of an agent in the user's library
|
||||
|
||||
For scheduled execution, provide: schedule_name, cron, and optionally timezone."""
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
@@ -117,36 +125,40 @@ class RunAgentTool(BaseTool):
|
||||
"properties": {
|
||||
"username_agent_slug": {
|
||||
"type": "string",
|
||||
"description": "Marketplace format 'username/agent-name'.",
|
||||
"description": "Agent identifier in format 'username/agent-name'",
|
||||
},
|
||||
"library_agent_id": {
|
||||
"type": "string",
|
||||
"description": "Library agent ID.",
|
||||
"description": "Library agent ID from user's library",
|
||||
},
|
||||
"inputs": {
|
||||
"type": "object",
|
||||
"description": "Input values for the agent.",
|
||||
"description": "Input values for the agent",
|
||||
"additionalProperties": True,
|
||||
},
|
||||
"use_defaults": {
|
||||
"type": "boolean",
|
||||
"description": "Run with default values (confirm with user first).",
|
||||
"description": "Set to true to run with default values (user must confirm)",
|
||||
},
|
||||
"schedule_name": {
|
||||
"type": "string",
|
||||
"description": "Name for scheduled execution.",
|
||||
"description": "Name for scheduled execution (triggers scheduling mode)",
|
||||
},
|
||||
"cron": {
|
||||
"type": "string",
|
||||
"description": "Cron expression (min hour day month weekday).",
|
||||
"description": "Cron expression (5 fields: min hour day month weekday)",
|
||||
},
|
||||
"timezone": {
|
||||
"type": "string",
|
||||
"description": "IANA timezone (default: UTC).",
|
||||
"description": "IANA timezone for schedule (default: UTC)",
|
||||
},
|
||||
"wait_for_result": {
|
||||
"type": "integer",
|
||||
"description": "Max seconds to wait for completion (0-300).",
|
||||
"description": (
|
||||
"Max seconds to wait for execution to complete (0-300). "
|
||||
"If >0, blocks until the execution finishes or times out. "
|
||||
"Returns execution outputs when complete."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
|
||||
@@ -45,10 +45,13 @@ class RunBlockTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Execute a block. IMPORTANT: Always get block_id from find_block first "
|
||||
"— do NOT guess or fabricate IDs. "
|
||||
"Call with empty input_data to see schema, then with data to execute. "
|
||||
"If review_required, use continue_run_block."
|
||||
"Execute a specific block with the provided input data. "
|
||||
"IMPORTANT: You MUST call find_block first to get the block's 'id' - "
|
||||
"do NOT guess or make up block IDs. "
|
||||
"On first attempt (without input_data), returns detailed schema showing "
|
||||
"required inputs and outputs. Then call again with proper input_data to execute. "
|
||||
"If a block requires human review, use continue_run_block with the "
|
||||
"review_id after the user approves."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -58,14 +61,28 @@ class RunBlockTool(BaseTool):
|
||||
"properties": {
|
||||
"block_id": {
|
||||
"type": "string",
|
||||
"description": "Block ID from find_block results.",
|
||||
"description": (
|
||||
"The block's 'id' field from find_block results. "
|
||||
"NEVER guess this - always get it from find_block first."
|
||||
),
|
||||
},
|
||||
"block_name": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The block's human-readable name from find_block results. "
|
||||
"Used for display purposes in the UI."
|
||||
),
|
||||
},
|
||||
"input_data": {
|
||||
"type": "object",
|
||||
"description": "Input values. Use {} first to see schema.",
|
||||
"description": (
|
||||
"Input values for the block. "
|
||||
"First call with empty {} to see the block's schema, "
|
||||
"then call again with proper values to execute."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["block_id", "input_data"],
|
||||
"required": ["block_id", "block_name", "input_data"],
|
||||
}
|
||||
|
||||
@property
|
||||
|
||||
@@ -57,9 +57,10 @@ class RunMCPToolTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Discover and execute MCP server tools. "
|
||||
"Call with server_url only to list tools, then with tool_name + tool_arguments to execute. "
|
||||
"Call get_mcp_guide first for server URLs and auth."
|
||||
"Connect to an MCP (Model Context Protocol) server to discover and execute its tools. "
|
||||
"Two-step: (1) call with server_url to list available tools, "
|
||||
"(2) call again with server_url + tool_name + tool_arguments to execute. "
|
||||
"Call get_mcp_guide for known server URLs and auth details."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -69,15 +70,24 @@ class RunMCPToolTool(BaseTool):
|
||||
"properties": {
|
||||
"server_url": {
|
||||
"type": "string",
|
||||
"description": "MCP server URL (Streamable HTTP endpoint).",
|
||||
"description": (
|
||||
"URL of the MCP server (Streamable HTTP endpoint), "
|
||||
"e.g. https://mcp.example.com/mcp"
|
||||
),
|
||||
},
|
||||
"tool_name": {
|
||||
"type": "string",
|
||||
"description": "Tool to execute. Omit to discover available tools.",
|
||||
"description": (
|
||||
"Name of the MCP tool to execute. "
|
||||
"Omit on first call to discover available tools."
|
||||
),
|
||||
},
|
||||
"tool_arguments": {
|
||||
"type": "object",
|
||||
"description": "Arguments matching the tool's input schema.",
|
||||
"description": (
|
||||
"Arguments to pass to the selected tool. "
|
||||
"Must match the tool's input schema returned during discovery."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["server_url"],
|
||||
|
||||
@@ -38,7 +38,11 @@ class SearchDocsTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Search platform documentation by keyword. Use get_doc_page to read full results."
|
||||
return (
|
||||
"Search the AutoGPT platform documentation for information about "
|
||||
"how to use the platform, build agents, configure blocks, and more. "
|
||||
"Returns relevant documentation sections. Use get_doc_page to read full content."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
@@ -47,7 +51,10 @@ class SearchDocsTool(BaseTool):
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Documentation search query.",
|
||||
"description": (
|
||||
"Search query to find relevant documentation. "
|
||||
"Use natural language to describe what you're looking for."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
"""Schema regression tests for all registered CoPilot tools.
|
||||
|
||||
Validates that every tool in TOOL_REGISTRY produces a well-formed schema:
|
||||
- description is non-empty
|
||||
- all `required` fields exist in `properties`
|
||||
- every property has a `type` and `description`
|
||||
- total token budget does not regress past 8000 tokens
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import tiktoken
|
||||
|
||||
from backend.copilot.tools import TOOL_REGISTRY
|
||||
|
||||
_TOKEN_BUDGET = 8_000
|
||||
|
||||
|
||||
def _get_all_tool_schemas() -> list[tuple[str, object]]:
|
||||
"""Return (tool_name, openai_schema) pairs for every registered tool."""
|
||||
return [(name, tool.as_openai_tool()) for name, tool in TOOL_REGISTRY.items()]
|
||||
|
||||
|
||||
_ALL_SCHEMAS = _get_all_tool_schemas()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tool_name,schema",
|
||||
_ALL_SCHEMAS,
|
||||
ids=[name for name, _ in _ALL_SCHEMAS],
|
||||
)
|
||||
class TestToolSchema:
|
||||
"""Validate schema invariants for every registered tool."""
|
||||
|
||||
def test_description_non_empty(self, tool_name: str, schema: dict) -> None:
|
||||
desc = schema["function"].get("description", "")
|
||||
assert desc, f"Tool '{tool_name}' has an empty description"
|
||||
|
||||
def test_required_fields_exist_in_properties(
|
||||
self, tool_name: str, schema: dict
|
||||
) -> None:
|
||||
params = schema["function"].get("parameters", {})
|
||||
properties = params.get("properties", {})
|
||||
required = params.get("required", [])
|
||||
for field in required:
|
||||
assert field in properties, (
|
||||
f"Tool '{tool_name}': required field '{field}' "
|
||||
f"not found in properties {list(properties.keys())}"
|
||||
)
|
||||
|
||||
def test_every_property_has_type_and_description(
|
||||
self, tool_name: str, schema: dict
|
||||
) -> None:
|
||||
params = schema["function"].get("parameters", {})
|
||||
properties = params.get("properties", {})
|
||||
for prop_name, prop_def in properties.items():
|
||||
assert (
|
||||
"type" in prop_def
|
||||
), f"Tool '{tool_name}', property '{prop_name}' is missing 'type'"
|
||||
assert (
|
||||
"description" in prop_def
|
||||
), f"Tool '{tool_name}', property '{prop_name}' is missing 'description'"
|
||||
|
||||
|
||||
def test_total_schema_token_budget() -> None:
|
||||
"""Assert total tool schema size stays under the token budget.
|
||||
|
||||
This locks in the 34% token reduction from #12398 and prevents future
|
||||
description bloat from eroding the gains. Budget is set to 8000 tokens.
|
||||
Note: this measures tool JSON only (not the full system prompt); the actual
|
||||
baseline for tool schemas alone is ~6470 tokens, giving ~19% headroom.
|
||||
"""
|
||||
schemas = [tool.as_openai_tool() for tool in TOOL_REGISTRY.values()]
|
||||
serialized = json.dumps(schemas)
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
total_tokens = len(enc.encode(serialized))
|
||||
assert total_tokens < _TOKEN_BUDGET, (
|
||||
f"Tool schemas use {total_tokens} tokens, exceeding budget of {_TOKEN_BUDGET}. "
|
||||
f"Description bloat detected — trim descriptions or raise the budget intentionally."
|
||||
)
|
||||
@@ -21,7 +21,19 @@ class ValidateAgentGraphTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Validate agent JSON for correctness (block_ids, links, types, schemas). On failure, use fix_agent_graph to auto-fix."
|
||||
return (
|
||||
"Validate an agent JSON graph for correctness. Checks:\n"
|
||||
"- All block_ids reference real blocks\n"
|
||||
"- All links reference valid source/sink nodes and fields\n"
|
||||
"- Required input fields are wired or have defaults\n"
|
||||
"- Data types are compatible across links\n"
|
||||
"- Nested sink links use correct notation\n"
|
||||
"- Prompt templates use proper curly brace escaping\n"
|
||||
"- AgentExecutorBlock configurations are valid\n\n"
|
||||
"Call this after generating agent JSON to verify correctness. "
|
||||
"If validation fails, either fix issues manually based on the error "
|
||||
"descriptions, or call fix_agent_graph to auto-fix common problems."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
@@ -34,7 +46,11 @@ class ValidateAgentGraphTool(BaseTool):
|
||||
"properties": {
|
||||
"agent_json": {
|
||||
"type": "object",
|
||||
"description": "Agent JSON with 'nodes' and 'links' arrays.",
|
||||
"description": (
|
||||
"The agent JSON to validate. Must contain 'nodes' and 'links' arrays. "
|
||||
"Each node needs: id (UUID), block_id, input_default, metadata. "
|
||||
"Each link needs: id (UUID), source_id, source_name, sink_id, sink_name."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["agent_json"],
|
||||
|
||||
@@ -59,7 +59,13 @@ class WebFetchTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Fetch a public web page. Public URLs only — internal addresses blocked. Returns readable text from HTML by default."
|
||||
return (
|
||||
"Fetch the content of a public web page by URL. "
|
||||
"Returns readable text extracted from HTML by default. "
|
||||
"Useful for reading documentation, articles, and API responses. "
|
||||
"Only supports HTTP/HTTPS GET requests to public URLs "
|
||||
"(private/internal network addresses are blocked)."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
@@ -68,11 +74,14 @@ class WebFetchTool(BaseTool):
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "Public HTTP/HTTPS URL.",
|
||||
"description": "The public HTTP/HTTPS URL to fetch.",
|
||||
},
|
||||
"extract_text": {
|
||||
"type": "boolean",
|
||||
"description": "Extract text from HTML (default: true).",
|
||||
"description": (
|
||||
"If true (default), extract readable text from HTML. "
|
||||
"If false, return raw content."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -321,7 +321,13 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "List persistent workspace files. For ephemeral session files, use SDK Glob/Read instead. Optionally filter by path prefix."
|
||||
return (
|
||||
"List files in the user's persistent workspace (cloud storage). "
|
||||
"These files survive across sessions. "
|
||||
"For ephemeral session files, use the SDK Read/Glob tools instead. "
|
||||
"Returns file names, paths, sizes, and metadata. "
|
||||
"Optionally filter by path prefix."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
@@ -330,17 +336,24 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
"properties": {
|
||||
"path_prefix": {
|
||||
"type": "string",
|
||||
"description": "Filter by path prefix (e.g. '/documents/').",
|
||||
"description": (
|
||||
"Optional path prefix to filter files "
|
||||
"(e.g., '/documents/' to list only files in documents folder). "
|
||||
"By default, only files from the current session are listed."
|
||||
),
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max files to return (default 50, max 100).",
|
||||
"description": "Maximum number of files to return (default 50, max 100)",
|
||||
"minimum": 1,
|
||||
"maximum": 100,
|
||||
},
|
||||
"include_all_sessions": {
|
||||
"type": "boolean",
|
||||
"description": "Include files from all sessions (default: false).",
|
||||
"description": (
|
||||
"If true, list files from all sessions. "
|
||||
"Default is false (only current session's files)."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
@@ -423,10 +436,18 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Read a file from persistent workspace. Specify file_id or path. "
|
||||
"Small text/image files return inline; large/binary return metadata+URL. "
|
||||
"Use save_to_path to copy to working dir for processing. "
|
||||
"Use offset/length for paginated reads."
|
||||
"Read a file from the user's persistent workspace (cloud storage). "
|
||||
"These files survive across sessions. "
|
||||
"For ephemeral session files, use the SDK Read tool instead. "
|
||||
"Specify either file_id or path to identify the file. "
|
||||
"For small text files, returns content directly. "
|
||||
"For large or binary files, returns metadata and a download URL. "
|
||||
"Use 'save_to_path' to copy the file to the working directory "
|
||||
"(sandbox or ephemeral) for processing with bash_exec or file tools. "
|
||||
"Use 'offset' and 'length' for paginated reads of large files "
|
||||
"(e.g., persisted tool outputs). "
|
||||
"Paths are scoped to the current session by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -436,30 +457,48 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
"properties": {
|
||||
"file_id": {
|
||||
"type": "string",
|
||||
"description": "File ID from list_workspace_files.",
|
||||
"description": "The file's unique ID (from list_workspace_files)",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Virtual file path (e.g. '/documents/report.pdf').",
|
||||
"description": (
|
||||
"The virtual file path (e.g., '/documents/report.pdf'). "
|
||||
"Scoped to current session by default."
|
||||
),
|
||||
},
|
||||
"save_to_path": {
|
||||
"type": "string",
|
||||
"description": "Copy file to this working directory path for processing.",
|
||||
"description": (
|
||||
"If provided, save the file to this path in the working "
|
||||
"directory (cloud sandbox when E2B is active, or "
|
||||
"ephemeral dir otherwise) so it can be processed with "
|
||||
"bash_exec or file tools. "
|
||||
"The file content is still returned in the response."
|
||||
),
|
||||
},
|
||||
"force_download_url": {
|
||||
"type": "boolean",
|
||||
"description": "Always return metadata+URL instead of inline content.",
|
||||
"description": (
|
||||
"If true, always return metadata+URL instead of inline content. "
|
||||
"Default is false (auto-selects based on file size/type)."
|
||||
),
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Character offset for paginated reads (0-based).",
|
||||
"description": (
|
||||
"Character offset to start reading from (0-based). "
|
||||
"Use with 'length' for paginated reads of large files."
|
||||
),
|
||||
},
|
||||
"length": {
|
||||
"type": "integer",
|
||||
"description": "Max characters to return for paginated reads.",
|
||||
"description": (
|
||||
"Maximum number of characters to return. "
|
||||
"Defaults to full file. Use with 'offset' for paginated reads."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [], # At least one of file_id or path must be provided
|
||||
"required": [], # At least one must be provided
|
||||
}
|
||||
|
||||
@property
|
||||
@@ -614,9 +653,15 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Write a file to persistent workspace (survives across sessions). "
|
||||
"Provide exactly one of: content (text), content_base64 (binary), "
|
||||
f"or source_path (copy from working dir). Max {Config().max_file_size_mb}MB."
|
||||
"Write or create a file in the user's persistent workspace (cloud storage). "
|
||||
"These files survive across sessions. "
|
||||
"For ephemeral session files, use the SDK Write tool instead. "
|
||||
"Provide content as plain text via 'content', OR base64-encoded via "
|
||||
"'content_base64', OR copy a file from the ephemeral working directory "
|
||||
"via 'source_path'. Exactly one of these three is required. "
|
||||
f"Maximum file size is {Config().max_file_size_mb}MB. "
|
||||
"Files are saved to the current session's folder by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -626,31 +671,51 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
"properties": {
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "Filename (e.g. 'report.pdf').",
|
||||
"description": "Name for the file (e.g., 'report.pdf')",
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Plain text content. Mutually exclusive with content_base64/source_path.",
|
||||
"description": (
|
||||
"Plain text content to write. Use this for text files "
|
||||
"(code, configs, documents, etc.). "
|
||||
"Mutually exclusive with content_base64 and source_path."
|
||||
),
|
||||
},
|
||||
"content_base64": {
|
||||
"type": "string",
|
||||
"description": "Base64-encoded binary content. Mutually exclusive with content/source_path.",
|
||||
"description": (
|
||||
"Base64-encoded file content. Use this for binary files "
|
||||
"(images, PDFs, etc.). "
|
||||
"Mutually exclusive with content and source_path."
|
||||
),
|
||||
},
|
||||
"source_path": {
|
||||
"type": "string",
|
||||
"description": "Working directory path to copy to workspace. Mutually exclusive with content/content_base64.",
|
||||
"description": (
|
||||
"Path to a file in the ephemeral working directory to "
|
||||
"copy to workspace (e.g., '/tmp/copilot-.../output.csv'). "
|
||||
"Use this to persist files created by bash_exec or SDK Write. "
|
||||
"Mutually exclusive with content and content_base64."
|
||||
),
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Virtual path (e.g. '/documents/report.pdf'). Defaults to '/{filename}'.",
|
||||
"description": (
|
||||
"Optional virtual path where to save the file "
|
||||
"(e.g., '/documents/report.pdf'). "
|
||||
"Defaults to '/{filename}'. Scoped to current session."
|
||||
),
|
||||
},
|
||||
"mime_type": {
|
||||
"type": "string",
|
||||
"description": "MIME type. Auto-detected from filename if omitted.",
|
||||
"description": (
|
||||
"Optional MIME type of the file. "
|
||||
"Auto-detected from filename if not provided."
|
||||
),
|
||||
},
|
||||
"overwrite": {
|
||||
"type": "boolean",
|
||||
"description": "Overwrite if file exists (default: false).",
|
||||
"description": "Whether to overwrite if file exists at path (default: false)",
|
||||
},
|
||||
},
|
||||
"required": ["filename"],
|
||||
@@ -777,7 +842,12 @@ class DeleteWorkspaceFileTool(BaseTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Delete a file from persistent workspace. Specify file_id or path."
|
||||
return (
|
||||
"Delete a file from the user's persistent workspace (cloud storage). "
|
||||
"Specify either file_id or path to identify the file. "
|
||||
"Paths are scoped to the current session by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
@@ -786,14 +856,17 @@ class DeleteWorkspaceFileTool(BaseTool):
|
||||
"properties": {
|
||||
"file_id": {
|
||||
"type": "string",
|
||||
"description": "File ID from list_workspace_files.",
|
||||
"description": "The file's unique ID (from list_workspace_files)",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Virtual file path.",
|
||||
"description": (
|
||||
"The virtual file path (e.g., '/documents/report.pdf'). "
|
||||
"Scoped to current session by default."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [], # At least one of file_id or path must be provided
|
||||
"required": [], # At least one must be provided
|
||||
}
|
||||
|
||||
@property
|
||||
|
||||
@@ -25,6 +25,35 @@ logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
|
||||
_on_creds_changed: Callable[[str, str], None] | None = None
|
||||
|
||||
|
||||
def register_creds_changed_hook(hook: Callable[[str, str], None]) -> None:
|
||||
"""Register a callback invoked after any credential is created/updated/deleted.
|
||||
|
||||
The callback receives ``(user_id, provider)`` and should be idempotent.
|
||||
Only one hook can be registered at a time; calling this again replaces the
|
||||
previous hook. Intended to be called once at application startup by the
|
||||
copilot module to bust its token cache without creating an import cycle.
|
||||
"""
|
||||
global _on_creds_changed
|
||||
_on_creds_changed = hook
|
||||
|
||||
|
||||
def _bust_copilot_cache(user_id: str, provider: str) -> None:
|
||||
"""Invoke the registered hook (if any) to bust downstream token caches."""
|
||||
if _on_creds_changed is not None:
|
||||
try:
|
||||
_on_creds_changed(user_id, provider)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Credential-change hook failed for user=%s provider=%s",
|
||||
user_id,
|
||||
provider,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
class IntegrationCredentialsManager:
|
||||
"""
|
||||
Handles the lifecycle of integration credentials.
|
||||
@@ -69,7 +98,11 @@ class IntegrationCredentialsManager:
|
||||
return self._locks
|
||||
|
||||
async def create(self, user_id: str, credentials: Credentials) -> None:
|
||||
return await self.store.add_creds(user_id, credentials)
|
||||
result = await self.store.add_creds(user_id, credentials)
|
||||
# Bust the copilot token cache so that the next bash_exec picks up the
|
||||
# new credential immediately instead of waiting for _NULL_CACHE_TTL.
|
||||
_bust_copilot_cache(user_id, credentials.provider)
|
||||
return result
|
||||
|
||||
async def exists(self, user_id: str, credentials_id: str) -> bool:
|
||||
return (await self.store.get_creds_by_id(user_id, credentials_id)) is not None
|
||||
@@ -156,6 +189,8 @@ class IntegrationCredentialsManager:
|
||||
|
||||
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
|
||||
await self.store.update_creds(user_id, fresh_credentials)
|
||||
# Bust copilot cache so the refreshed token is picked up immediately.
|
||||
_bust_copilot_cache(user_id, fresh_credentials.provider)
|
||||
if _lock and (await _lock.locked()) and (await _lock.owned()):
|
||||
try:
|
||||
await _lock.release()
|
||||
@@ -168,10 +203,17 @@ class IntegrationCredentialsManager:
|
||||
async def update(self, user_id: str, updated: Credentials) -> None:
|
||||
async with self._locked(user_id, updated.id):
|
||||
await self.store.update_creds(user_id, updated)
|
||||
# Bust the copilot token cache so the updated credential is picked up immediately.
|
||||
_bust_copilot_cache(user_id, updated.provider)
|
||||
|
||||
async def delete(self, user_id: str, credentials_id: str) -> None:
|
||||
async with self._locked(user_id, credentials_id):
|
||||
# Read inside the lock to avoid TOCTOU — another coroutine could
|
||||
# delete the same credential between the read and the delete.
|
||||
creds = await self.store.get_creds_by_id(user_id, credentials_id)
|
||||
await self.store.delete_creds_by_id(user_id, credentials_id)
|
||||
if creds:
|
||||
_bust_copilot_cache(user_id, creds.provider)
|
||||
|
||||
# -- Locking utilities -- #
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||
import { ExclamationMarkIcon } from "@phosphor-icons/react";
|
||||
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { useState } from "react";
|
||||
import { ConnectIntegrationTool } from "../../../tools/ConnectIntegrationTool/ConnectIntegrationTool";
|
||||
import { CreateAgentTool } from "../../../tools/CreateAgent/CreateAgent";
|
||||
import { EditAgentTool } from "../../../tools/EditAgent/EditAgent";
|
||||
import {
|
||||
@@ -129,6 +130,8 @@ export function MessagePartRenderer({ part, messageID, partIndex }: Props) {
|
||||
case "tool-search_docs":
|
||||
case "tool-get_doc_page":
|
||||
return <SearchDocsTool key={key} part={part as ToolUIPart} />;
|
||||
case "tool-connect_integration":
|
||||
return <ConnectIntegrationTool key={key} part={part as ToolUIPart} />;
|
||||
case "tool-run_block":
|
||||
case "tool-continue_run_block":
|
||||
return <RunBlockTool key={key} part={part as ToolUIPart} />;
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
"use client";
|
||||
|
||||
import type { SetupRequirementsResponse } from "@/app/api/__generated__/models/setupRequirementsResponse";
|
||||
import type { ToolUIPart } from "ai";
|
||||
import { useState } from "react";
|
||||
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
|
||||
import { ContentMessage } from "../../components/ToolAccordion/AccordionContent";
|
||||
import { SetupRequirementsCard } from "../RunBlock/components/SetupRequirementsCard/SetupRequirementsCard";
|
||||
|
||||
type Props = {
|
||||
part: ToolUIPart;
|
||||
};
|
||||
|
||||
function parseJson(raw: unknown): unknown {
|
||||
if (typeof raw === "string") {
|
||||
try {
|
||||
return JSON.parse(raw);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return raw;
|
||||
}
|
||||
|
||||
function parseOutput(raw: unknown): SetupRequirementsResponse | null {
|
||||
const parsed = parseJson(raw);
|
||||
if (parsed && typeof parsed === "object" && "setup_info" in parsed) {
|
||||
return parsed as SetupRequirementsResponse;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function parseError(raw: unknown): string | null {
|
||||
const parsed = parseJson(raw);
|
||||
if (parsed && typeof parsed === "object" && "message" in parsed) {
|
||||
return String((parsed as { message: unknown }).message);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
export function ConnectIntegrationTool({ part }: Props) {
|
||||
// Persist dismissed state here so SetupRequirementsCard remounts don't re-enable Proceed.
|
||||
const [isDismissed, setIsDismissed] = useState(false);
|
||||
|
||||
const isStreaming =
|
||||
part.state === "input-streaming" || part.state === "input-available";
|
||||
const isError = part.state === "output-error";
|
||||
|
||||
const output =
|
||||
part.state === "output-available"
|
||||
? parseOutput((part as { output?: unknown }).output)
|
||||
: null;
|
||||
|
||||
const errorMessage = isError
|
||||
? (parseError((part as { output?: unknown }).output) ??
|
||||
"Failed to connect integration")
|
||||
: null;
|
||||
|
||||
const rawProvider =
|
||||
(part as { input?: { provider?: string } }).input?.provider ?? "";
|
||||
const providerName =
|
||||
output?.setup_info?.agent_name ??
|
||||
// Sanitize LLM-controlled provider slug: trim and cap at 64 chars to
|
||||
// prevent runaway text in the DOM.
|
||||
(rawProvider ? rawProvider.trim().slice(0, 64) : "integration");
|
||||
|
||||
const label = isStreaming
|
||||
? `Connecting ${providerName}…`
|
||||
: isError
|
||||
? `Failed to connect ${providerName}`
|
||||
: output
|
||||
? `Connect ${output.setup_info?.agent_name ?? providerName}`
|
||||
: `Connect ${providerName}`;
|
||||
|
||||
return (
|
||||
<div className="py-2">
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<MorphingTextAnimation
|
||||
text={label}
|
||||
className={isError ? "text-red-500" : undefined}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{isError && errorMessage && (
|
||||
<p className="mt-1 text-sm text-red-500">{errorMessage}</p>
|
||||
)}
|
||||
|
||||
{output && (
|
||||
<div className="mt-2">
|
||||
{isDismissed ? (
|
||||
<ContentMessage>Connected. Continuing…</ContentMessage>
|
||||
) : (
|
||||
<SetupRequirementsCard
|
||||
output={output}
|
||||
credentialsLabel={`${output.setup_info?.agent_name ?? providerName} credentials`}
|
||||
retryInstruction="I've connected my account. Please continue."
|
||||
onComplete={() => setIsDismissed(true)}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -23,12 +23,16 @@ interface Props {
|
||||
/** Override the label shown above the credentials section.
|
||||
* Defaults to "Credentials". */
|
||||
credentialsLabel?: string;
|
||||
/** Called after Proceed is clicked so the parent can persist the dismissed state
|
||||
* across remounts (avoids re-enabling the Proceed button on remount). */
|
||||
onComplete?: () => void;
|
||||
}
|
||||
|
||||
export function SetupRequirementsCard({
|
||||
output,
|
||||
retryInstruction,
|
||||
credentialsLabel,
|
||||
onComplete,
|
||||
}: Props) {
|
||||
const { onSend } = useCopilotChatActions();
|
||||
|
||||
@@ -68,13 +72,17 @@ export function SetupRequirementsCard({
|
||||
return v !== undefined && v !== null && v !== "";
|
||||
});
|
||||
|
||||
if (hasSent) {
|
||||
return <ContentMessage>Connected. Continuing…</ContentMessage>;
|
||||
}
|
||||
|
||||
const canRun =
|
||||
!hasSent &&
|
||||
(!needsCredentials || isAllCredentialsComplete) &&
|
||||
(!needsInputs || isAllInputsComplete);
|
||||
|
||||
function handleRun() {
|
||||
setHasSent(true);
|
||||
onComplete?.();
|
||||
|
||||
const parts: string[] = [];
|
||||
if (needsCredentials) {
|
||||
|
||||
@@ -125,9 +125,9 @@ export function useCredentialsInput({
|
||||
if (hasAttemptedAutoSelect.current) return;
|
||||
hasAttemptedAutoSelect.current = true;
|
||||
|
||||
// Auto-select if exactly one credential matches.
|
||||
// For optional fields with multiple options, let the user choose.
|
||||
if (isOptional && savedCreds.length > 1) return;
|
||||
// Auto-select only when there is exactly one saved credential.
|
||||
// With multiple options the user must choose — regardless of optional/required.
|
||||
if (savedCreds.length > 1) return;
|
||||
|
||||
const cred = savedCreds[0];
|
||||
onSelectCredential({
|
||||
|
||||
Reference in New Issue
Block a user