mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
8 Commits
feat/githu
...
feat/copil
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9190a28f24 | ||
|
|
0d21179a25 | ||
|
|
a7c13d676c | ||
|
|
d68de002f5 | ||
|
|
ea43bdf695 | ||
|
|
bbc4d9194f | ||
|
|
98a0d7dcc5 | ||
|
|
f2676de9d0 |
@@ -11,10 +11,7 @@ from backend.blocks._base import (
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import parse_data_uri, resolve_media_content
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
from ._api import get_api
|
||||
from ._auth import (
|
||||
@@ -181,8 +178,7 @@ class FileOperation(StrEnum):
|
||||
|
||||
class FileOperationInput(TypedDict):
|
||||
path: str
|
||||
# MediaFileType is a str NewType — no runtime breakage for existing callers.
|
||||
content: MediaFileType
|
||||
content: str
|
||||
operation: FileOperation
|
||||
|
||||
|
||||
@@ -279,11 +275,11 @@ class GithubMultiFileCommitBlock(Block):
|
||||
base_tree_sha = commit_data["tree"]["sha"]
|
||||
|
||||
# 3. Build tree entries for each file operation (blobs created concurrently)
|
||||
async def _create_blob(content: str, encoding: str = "utf-8") -> str:
|
||||
async def _create_blob(content: str) -> str:
|
||||
blob_url = repo_url + "/git/blobs"
|
||||
blob_response = await api.post(
|
||||
blob_url,
|
||||
json={"content": content, "encoding": encoding},
|
||||
json={"content": content, "encoding": "utf-8"},
|
||||
)
|
||||
return blob_response.json()["sha"]
|
||||
|
||||
@@ -305,19 +301,10 @@ class GithubMultiFileCommitBlock(Block):
|
||||
else:
|
||||
upsert_files.append((path, file_op.get("content", "")))
|
||||
|
||||
# Create all blobs concurrently. Data URIs (from store_media_file)
|
||||
# are sent as base64 blobs to preserve binary content.
|
||||
# Create all blobs concurrently
|
||||
if upsert_files:
|
||||
|
||||
async def _make_blob(content: str) -> str:
|
||||
parsed = parse_data_uri(content)
|
||||
if parsed is not None:
|
||||
_, b64_payload = parsed
|
||||
return await _create_blob(b64_payload, encoding="base64")
|
||||
return await _create_blob(content)
|
||||
|
||||
blob_shas = await asyncio.gather(
|
||||
*[_make_blob(content) for _, content in upsert_files]
|
||||
*[_create_blob(content) for _, content in upsert_files]
|
||||
)
|
||||
for (path, _), blob_sha in zip(upsert_files, blob_shas):
|
||||
tree_entries.append(
|
||||
@@ -371,36 +358,15 @@ class GithubMultiFileCommitBlock(Block):
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
# Resolve media references (workspace://, data:, URLs) to data
|
||||
# URIs so _make_blob can send binary content correctly.
|
||||
resolved_files: list[FileOperationInput] = []
|
||||
for file_op in input_data.files:
|
||||
content = file_op.get("content", "")
|
||||
operation = FileOperation(file_op.get("operation", "upsert"))
|
||||
if operation != FileOperation.DELETE:
|
||||
content = await resolve_media_content(
|
||||
MediaFileType(content),
|
||||
execution_context,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
resolved_files.append(
|
||||
FileOperationInput(
|
||||
path=file_op["path"],
|
||||
content=MediaFileType(content),
|
||||
operation=operation,
|
||||
)
|
||||
)
|
||||
|
||||
sha, url = await self.multi_file_commit(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
input_data.commit_message,
|
||||
resolved_files,
|
||||
input_data.files,
|
||||
)
|
||||
yield "sha", sha
|
||||
yield "url", url
|
||||
|
||||
@@ -8,7 +8,6 @@ from backend.blocks.github.pull_requests import (
|
||||
GithubMergePullRequestBlock,
|
||||
prepare_pr_api_url,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
|
||||
# ── prepare_pr_api_url tests ──
|
||||
@@ -98,11 +97,7 @@ async def test_multi_file_commit_error_path():
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
}
|
||||
with pytest.raises(BlockExecutionError, match="ref update failed"):
|
||||
async for _ in block.execute(
|
||||
input_data,
|
||||
credentials=TEST_CREDENTIALS,
|
||||
execution_context=ExecutionContext(),
|
||||
):
|
||||
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -11,8 +11,6 @@ from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
@@ -84,17 +82,6 @@ def resolve_sandbox_path(path: str) -> str:
|
||||
return normalized
|
||||
|
||||
|
||||
async def get_workspace_manager(user_id: str, session_id: str) -> WorkspaceManager:
|
||||
"""Create a session-scoped :class:`WorkspaceManager`.
|
||||
|
||||
Placed here (rather than in ``tools/workspace_files``) so that modules
|
||||
like ``sdk/file_ref`` can import it without triggering the heavy
|
||||
``tools/__init__`` import chain.
|
||||
"""
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
return WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
|
||||
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
"""Return True if *path* is within an allowed host-filesystem location.
|
||||
|
||||
|
||||
@@ -1,162 +0,0 @@
|
||||
"""Integration credential lookup with per-process TTL cache.
|
||||
|
||||
Provides token retrieval for connected integrations so that copilot tools
|
||||
(e.g. bash_exec) can inject auth tokens into the execution environment without
|
||||
hitting the database on every command.
|
||||
|
||||
Cache semantics (handled automatically by TTLCache):
|
||||
- Token found → cached for _TOKEN_CACHE_TTL (5 min). Avoids repeated DB hits
|
||||
for users who have credentials and are running many bash commands.
|
||||
- No credentials found → cached for _NULL_CACHE_TTL (60 s). Avoids a DB hit
|
||||
on every E2B command for users who haven't connected an account yet, while
|
||||
still picking up a newly-connected account within one minute.
|
||||
|
||||
Both caches are bounded to _CACHE_MAX_SIZE entries; cachetools evicts the
|
||||
least-recently-used entry when the limit is reached.
|
||||
|
||||
Multi-worker note: both caches are in-process only. Each worker/replica
|
||||
maintains its own independent cache, so a credential fetch may be duplicated
|
||||
across processes. This is acceptable for the current goal (reduce DB hits per
|
||||
session per-process), but if cache efficiency across replicas becomes important
|
||||
a shared cache (e.g. Redis) should be used instead.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
register_creds_changed_hook,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maps provider slug → env var names to inject when the provider is connected.
|
||||
# Add new providers here when adding integration support.
|
||||
# NOTE: keep in sync with connect_integration._PROVIDER_INFO — both registries
|
||||
# must be updated when adding a new provider.
|
||||
PROVIDER_ENV_VARS: dict[str, list[str]] = {
|
||||
"github": ["GH_TOKEN", "GITHUB_TOKEN"],
|
||||
}
|
||||
|
||||
_TOKEN_CACHE_TTL = 300.0 # seconds — for found tokens
|
||||
_NULL_CACHE_TTL = 60.0 # seconds — for "not connected" results
|
||||
_CACHE_MAX_SIZE = 10_000
|
||||
|
||||
# (user_id, provider) → token string. TTLCache handles expiry + eviction.
|
||||
# Thread-safety note: TTLCache is NOT thread-safe, but that is acceptable here
|
||||
# because all callers (get_provider_token, invalidate_user_provider_cache) run
|
||||
# exclusively on the asyncio event loop. There are no await points between a
|
||||
# cache read and its corresponding write within any function, so no concurrent
|
||||
# coroutine can interleave. If ThreadPoolExecutor workers are ever added to
|
||||
# this path, a threading.RLock should be wrapped around these caches.
|
||||
_token_cache: TTLCache[tuple[str, str], str] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_TOKEN_CACHE_TTL
|
||||
)
|
||||
# Separate cache for "no credentials" results with a shorter TTL.
|
||||
_null_cache: TTLCache[tuple[str, str], bool] = TTLCache(
|
||||
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
|
||||
)
|
||||
|
||||
|
||||
def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
|
||||
"""Remove the cached entry for *user_id*/*provider* from both caches.
|
||||
|
||||
Call this after storing new credentials so that the next
|
||||
``get_provider_token()`` call performs a fresh DB lookup instead of
|
||||
serving a stale TTL-cached result.
|
||||
"""
|
||||
key = (user_id, provider)
|
||||
_token_cache.pop(key, None)
|
||||
_null_cache.pop(key, None)
|
||||
|
||||
|
||||
# Register this module's cache-bust function with the credentials manager so
|
||||
# that any create/update/delete operation immediately evicts stale cache
|
||||
# entries. This avoids a lazy import inside creds_manager and eliminates the
|
||||
# circular-import risk.
|
||||
register_creds_changed_hook(invalidate_user_provider_cache)
|
||||
|
||||
# Module-level singleton to avoid re-instantiating IntegrationCredentialsManager
|
||||
# on every cache-miss call to get_provider_token().
|
||||
_manager = IntegrationCredentialsManager()
|
||||
|
||||
|
||||
async def get_provider_token(user_id: str, provider: str) -> str | None:
|
||||
"""Return the user's access token for *provider*, or ``None`` if not connected.
|
||||
|
||||
OAuth2 tokens are preferred (refreshed if needed); API keys are the fallback.
|
||||
Found tokens are cached for _TOKEN_CACHE_TTL (5 min). "Not connected" results
|
||||
are cached for _NULL_CACHE_TTL (60 s) to avoid a DB hit on every bash_exec
|
||||
command for users who haven't connected yet, while still picking up a
|
||||
newly-connected account within one minute.
|
||||
"""
|
||||
cache_key = (user_id, provider)
|
||||
|
||||
if cache_key in _null_cache:
|
||||
return None
|
||||
if cached := _token_cache.get(cache_key):
|
||||
return cached
|
||||
|
||||
manager = _manager
|
||||
try:
|
||||
creds_list = await manager.store.get_creds_by_provider(user_id, provider)
|
||||
except Exception:
|
||||
logger.debug("Failed to fetch %s credentials for user %s", provider, user_id)
|
||||
return None
|
||||
|
||||
# Pass 1: prefer OAuth2 (carry scope info, refreshable via token endpoint).
|
||||
# Sort so broader-scoped tokens come first: a token with "repo" scope covers
|
||||
# full git access, while a public-data-only token lacks push/pull permission.
|
||||
# lock=False — background injection; not worth a distributed lock acquisition.
|
||||
oauth2_creds = sorted(
|
||||
[c for c in creds_list if c.type == "oauth2"],
|
||||
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
|
||||
)
|
||||
for creds in oauth2_creds:
|
||||
if creds.type == "oauth2":
|
||||
try:
|
||||
fresh = await manager.refresh_if_needed(
|
||||
user_id, cast(OAuth2Credentials, creds), lock=False
|
||||
)
|
||||
token = fresh.access_token.get_secret_value()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to refresh %s OAuth token for user %s; "
|
||||
"falling back to potentially stale token",
|
||||
provider,
|
||||
user_id,
|
||||
)
|
||||
token = cast(OAuth2Credentials, creds).access_token.get_secret_value()
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# Pass 2: fall back to API key (no expiry, no refresh needed).
|
||||
for creds in creds_list:
|
||||
if creds.type == "api_key":
|
||||
token = cast(APIKeyCredentials, creds).api_key.get_secret_value()
|
||||
_token_cache[cache_key] = token
|
||||
return token
|
||||
|
||||
# No credentials found — cache to avoid repeated DB hits.
|
||||
_null_cache[cache_key] = True
|
||||
return None
|
||||
|
||||
|
||||
async def get_integration_env_vars(user_id: str) -> dict[str, str]:
|
||||
"""Return env vars for all providers the user has connected.
|
||||
|
||||
Iterates :data:`PROVIDER_ENV_VARS`, fetches each token, and builds a flat
|
||||
``{env_var: token}`` dict ready to pass to a subprocess or E2B sandbox.
|
||||
Only providers with a stored credential contribute entries.
|
||||
"""
|
||||
env: dict[str, str] = {}
|
||||
for provider, var_names in PROVIDER_ENV_VARS.items():
|
||||
token = await get_provider_token(user_id, provider)
|
||||
if token:
|
||||
for var in var_names:
|
||||
env[var] = token
|
||||
return env
|
||||
@@ -1,193 +0,0 @@
|
||||
"""Tests for integration_creds — TTL cache and token lookup paths."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.copilot.integration_creds import (
|
||||
_NULL_CACHE_TTL,
|
||||
_TOKEN_CACHE_TTL,
|
||||
PROVIDER_ENV_VARS,
|
||||
_null_cache,
|
||||
_token_cache,
|
||||
get_integration_env_vars,
|
||||
get_provider_token,
|
||||
invalidate_user_provider_cache,
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
|
||||
_USER = "user-integration-creds-test"
|
||||
_PROVIDER = "github"
|
||||
|
||||
|
||||
def _make_api_key_creds(key: str = "test-api-key") -> APIKeyCredentials:
|
||||
return APIKeyCredentials(
|
||||
id="creds-api-key",
|
||||
provider=_PROVIDER,
|
||||
api_key=SecretStr(key),
|
||||
title="Test API Key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_oauth2_creds(token: str = "test-oauth-token") -> OAuth2Credentials:
|
||||
return OAuth2Credentials(
|
||||
id="creds-oauth2",
|
||||
provider=_PROVIDER,
|
||||
title="Test OAuth",
|
||||
access_token=SecretStr(token),
|
||||
refresh_token=SecretStr("test-refresh"),
|
||||
access_token_expires_at=None,
|
||||
refresh_token_expires_at=None,
|
||||
scopes=[],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_caches():
|
||||
"""Ensure clean caches before and after every test."""
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
yield
|
||||
_token_cache.clear()
|
||||
_null_cache.clear()
|
||||
|
||||
|
||||
class TestInvalidateUserProviderCache:
|
||||
def test_removes_token_entry(self):
|
||||
key = (_USER, _PROVIDER)
|
||||
_token_cache[key] = "tok"
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert key not in _token_cache
|
||||
|
||||
def test_removes_null_entry(self):
|
||||
key = (_USER, _PROVIDER)
|
||||
_null_cache[key] = True
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert key not in _null_cache
|
||||
|
||||
def test_noop_when_key_not_cached(self):
|
||||
# Should not raise even when there is no cache entry.
|
||||
invalidate_user_provider_cache("no-such-user", _PROVIDER)
|
||||
|
||||
def test_only_removes_targeted_key(self):
|
||||
other_key = ("other-user", _PROVIDER)
|
||||
_token_cache[other_key] = "other-tok"
|
||||
invalidate_user_provider_cache(_USER, _PROVIDER)
|
||||
assert other_key in _token_cache
|
||||
|
||||
|
||||
class TestGetProviderToken:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_returns_cached_token_without_db_hit(self):
|
||||
_token_cache[(_USER, _PROVIDER)] = "cached-tok"
|
||||
|
||||
mock_manager = MagicMock()
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "cached-tok"
|
||||
mock_manager.store.get_creds_by_provider.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_returns_none_for_null_cached_provider(self):
|
||||
_null_cache[(_USER, _PROVIDER)] = True
|
||||
|
||||
mock_manager = MagicMock()
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
mock_manager.store.get_creds_by_provider.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_api_key_creds_returned_and_cached(self):
|
||||
api_creds = _make_api_key_creds("my-api-key")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[api_creds])
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "my-api-key"
|
||||
assert _token_cache.get((_USER, _PROVIDER)) == "my-api-key"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_preferred_over_api_key(self):
|
||||
oauth_creds = _make_oauth2_creds("oauth-tok")
|
||||
api_creds = _make_api_key_creds("api-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(
|
||||
return_value=[api_creds, oauth_creds]
|
||||
)
|
||||
mock_manager.refresh_if_needed = AsyncMock(return_value=oauth_creds)
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "oauth-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth2_refresh_failure_falls_back_to_stale_token(self):
|
||||
oauth_creds = _make_oauth2_creds("stale-oauth-tok")
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
|
||||
mock_manager.refresh_if_needed = AsyncMock(side_effect=RuntimeError("network"))
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result == "stale-oauth-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_credentials_caches_null_entry(self):
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
assert _null_cache.get((_USER, _PROVIDER)) is True
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_db_exception_returns_none_without_caching(self):
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.store.get_creds_by_provider = AsyncMock(
|
||||
side_effect=RuntimeError("db down")
|
||||
)
|
||||
|
||||
with patch("backend.copilot.integration_creds._manager", mock_manager):
|
||||
result = await get_provider_token(_USER, _PROVIDER)
|
||||
|
||||
assert result is None
|
||||
# DB errors are not cached — next call will retry
|
||||
assert (_USER, _PROVIDER) not in _token_cache
|
||||
assert (_USER, _PROVIDER) not in _null_cache
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_null_cache_has_shorter_ttl_than_token_cache(self):
|
||||
"""Verify the TTL constants are set correctly for each cache."""
|
||||
assert _null_cache.ttl == _NULL_CACHE_TTL
|
||||
assert _token_cache.ttl == _TOKEN_CACHE_TTL
|
||||
assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL
|
||||
|
||||
|
||||
class TestGetIntegrationEnvVars:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_injects_all_env_vars_for_provider(self):
|
||||
_token_cache[(_USER, "github")] = "gh-tok"
|
||||
|
||||
result = await get_integration_env_vars(_USER)
|
||||
|
||||
for var in PROVIDER_ENV_VARS["github"]:
|
||||
assert result[var] == "gh-tok"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_empty_dict_when_no_credentials(self):
|
||||
_null_cache[(_USER, "github")] = True
|
||||
|
||||
result = await get_integration_env_vars(_USER)
|
||||
|
||||
assert result == {}
|
||||
@@ -52,68 +52,17 @@ Examples:
|
||||
You can embed a reference inside any string argument, or use it as the entire
|
||||
value. Multiple references in one argument are all expanded.
|
||||
|
||||
**Structured data**: When the **entire** argument value is a single file
|
||||
reference (no surrounding text), the platform automatically parses the file
|
||||
content based on its extension or MIME type. Supported formats: JSON, JSONL,
|
||||
CSV, TSV, YAML, TOML, Parquet, and Excel (.xlsx — first sheet only).
|
||||
For example, pass `@@agptfile:workspace://<id>` where the file is a `.csv` and
|
||||
the rows will be parsed into `list[list[str]]` automatically. If the format is
|
||||
unrecognised or parsing fails, the content is returned as a plain string.
|
||||
Legacy `.xls` files are **not** supported — only the modern `.xlsx` format.
|
||||
**Type coercion**: The platform automatically coerces expanded string values
|
||||
to match the block's expected input types. For example, if a block expects
|
||||
`list[list[str]]` and you pass a string containing a JSON array (e.g. from
|
||||
an @@agptfile: expansion), the string will be parsed into the correct type.
|
||||
|
||||
**Type coercion**: The platform also coerces expanded values to match the
|
||||
block's expected input types. For example, if a block expects `list[list[str]]`
|
||||
and the expanded value is a JSON string, it will be parsed into the correct type.
|
||||
|
||||
### Media file inputs (format: "file")
|
||||
Some block inputs accept media files — their schema shows `"format": "file"`.
|
||||
These fields accept:
|
||||
- **`workspace://<file_id>`** or **`workspace://<file_id>#<mime>`** — preferred
|
||||
for large files (images, videos, PDFs). The platform passes the reference
|
||||
directly to the block without reading the content into memory.
|
||||
- **`data:<mime>;base64,<payload>`** — inline base64 data URI, suitable for
|
||||
small files only.
|
||||
|
||||
When a block input has `format: "file"`, **pass the `workspace://` URI
|
||||
directly as the value** (do NOT wrap it in `@@agptfile:`). This avoids large
|
||||
payloads in tool arguments and preserves binary content (images, videos)
|
||||
that would be corrupted by text encoding.
|
||||
|
||||
Example — committing an image file to GitHub:
|
||||
```json
|
||||
{
|
||||
"files": [{
|
||||
"path": "docs/hero.png",
|
||||
"content": "workspace://abc123#image/png",
|
||||
"operation": "upsert"
|
||||
}]
|
||||
}
|
||||
```
|
||||
|
||||
### Sub-agent tasks
|
||||
- When using the Task tool, NEVER set `run_in_background` to true.
|
||||
All tasks must run in the foreground.
|
||||
"""
|
||||
|
||||
# E2B-only notes — E2B has full internet access so gh CLI works there.
|
||||
# Not shown in local (bubblewrap) mode: --unshare-net blocks all network.
|
||||
_E2B_TOOL_NOTES = """
|
||||
### GitHub CLI (`gh`) and git
|
||||
- If the user has connected their GitHub account, both `gh` and `git` are
|
||||
pre-authenticated — use them directly without any manual login step.
|
||||
`git` HTTPS operations (clone, push, pull) work automatically.
|
||||
- If the token changes mid-session (e.g. user reconnects with a new token),
|
||||
run `gh auth setup-git` to re-register the credential helper.
|
||||
- If `gh` or `git` fails with an authentication error (e.g. "authentication
|
||||
required", "could not read Username", or exit code 128), call
|
||||
`connect_integration(provider="github")` to surface the GitHub credentials
|
||||
setup card so the user can connect their account. Once connected, retry
|
||||
the operation.
|
||||
- For operations that need broader access (e.g. private org repos, GitHub
|
||||
Actions), pass the required scopes: e.g.
|
||||
`connect_integration(provider="github", scopes=["repo", "read:org"])`.
|
||||
"""
|
||||
|
||||
|
||||
# Environment-specific supplement templates
|
||||
def _build_storage_supplement(
|
||||
@@ -124,7 +73,6 @@ def _build_storage_supplement(
|
||||
storage_system_1_persistence: list[str],
|
||||
file_move_name_1_to_2: str,
|
||||
file_move_name_2_to_1: str,
|
||||
extra_notes: str = "",
|
||||
) -> str:
|
||||
"""Build storage/filesystem supplement for a specific environment.
|
||||
|
||||
@@ -139,7 +87,6 @@ def _build_storage_supplement(
|
||||
storage_system_1_persistence: List of persistence behavior descriptions
|
||||
file_move_name_1_to_2: Direction label for primary→persistent
|
||||
file_move_name_2_to_1: Direction label for persistent→primary
|
||||
extra_notes: Environment-specific notes appended after shared notes
|
||||
"""
|
||||
# Format lists as bullet points with proper indentation
|
||||
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
|
||||
@@ -173,16 +120,12 @@ def _build_storage_supplement(
|
||||
|
||||
### File persistence
|
||||
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
|
||||
{_SHARED_TOOL_NOTES}{extra_notes}"""
|
||||
{_SHARED_TOOL_NOTES}"""
|
||||
|
||||
|
||||
# Pre-built supplements for common environments
|
||||
def _get_local_storage_supplement(cwd: str) -> str:
|
||||
"""Local ephemeral storage (files lost between turns).
|
||||
|
||||
Network is isolated (bubblewrap --unshare-net), so internet-dependent CLIs
|
||||
like gh will not work — no integration env-var notes are included.
|
||||
"""
|
||||
"""Local ephemeral storage (files lost between turns)."""
|
||||
return _build_storage_supplement(
|
||||
working_dir=cwd,
|
||||
sandbox_type="in a network-isolated sandbox",
|
||||
@@ -200,11 +143,7 @@ def _get_local_storage_supplement(cwd: str) -> str:
|
||||
|
||||
|
||||
def _get_cloud_sandbox_supplement() -> str:
|
||||
"""Cloud persistent sandbox (files survive across turns in session).
|
||||
|
||||
E2B has full internet access, so integration tokens (GH_TOKEN etc.) are
|
||||
injected per command in bash_exec — include the CLI guidance notes.
|
||||
"""
|
||||
"""Cloud persistent sandbox (files survive across turns in session)."""
|
||||
return _build_storage_supplement(
|
||||
working_dir="/home/user",
|
||||
sandbox_type="in a cloud sandbox with full internet access",
|
||||
@@ -219,7 +158,6 @@ def _get_cloud_sandbox_supplement() -> str:
|
||||
],
|
||||
file_move_name_1_to_2="Sandbox → Persistent",
|
||||
file_move_name_2_to_1="Persistent → Sandbox",
|
||||
extra_notes=_E2B_TOOL_NOTES,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -3,45 +3,12 @@
|
||||
This module provides the integration layer between the Claude Agent SDK
|
||||
and the existing CoPilot tool system, enabling drop-in replacement of
|
||||
the current LLM orchestration with the battle-tested Claude Agent SDK.
|
||||
|
||||
Submodule imports are deferred via PEP 562 ``__getattr__`` to break a
|
||||
circular import cycle::
|
||||
|
||||
sdk/__init__ → tool_adapter → copilot.tools (TOOL_REGISTRY)
|
||||
copilot.tools → run_block → sdk.file_ref (no cycle here, but…)
|
||||
sdk/__init__ → service → copilot.prompting → copilot.tools (cycle!)
|
||||
|
||||
``tool_adapter`` uses ``TOOL_REGISTRY`` at **module level** to build the
|
||||
static ``COPILOT_TOOL_NAMES`` list, so the import cannot be deferred to
|
||||
function scope without a larger refactor (moving tool-name registration
|
||||
to a separate lightweight module). The lazy-import pattern here is the
|
||||
least invasive way to break the cycle while keeping module-level constants
|
||||
intact.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from .service import stream_chat_completion_sdk
|
||||
from .tool_adapter import create_copilot_mcp_server
|
||||
|
||||
__all__ = [
|
||||
"stream_chat_completion_sdk",
|
||||
"create_copilot_mcp_server",
|
||||
]
|
||||
|
||||
# Dispatch table for PEP 562 lazy imports. Each entry is a (module, attr)
|
||||
# pair so new exports can be added without touching __getattr__ itself.
|
||||
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
|
||||
"stream_chat_completion_sdk": (".service", "stream_chat_completion_sdk"),
|
||||
"create_copilot_mcp_server": (".tool_adapter", "create_copilot_mcp_server"),
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
entry = _LAZY_IMPORTS.get(name)
|
||||
if entry is not None:
|
||||
module_path, attr = entry
|
||||
import importlib
|
||||
|
||||
module = importlib.import_module(module_path, package=__name__)
|
||||
value = getattr(module, attr)
|
||||
globals()[name] = value
|
||||
return value
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@@ -41,20 +41,12 @@ from typing import Any
|
||||
from backend.copilot.context import (
|
||||
get_current_sandbox,
|
||||
get_sdk_cwd,
|
||||
get_workspace_manager,
|
||||
is_allowed_local_path,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.workspace_files import get_manager
|
||||
from backend.util.file import parse_workspace_uri
|
||||
from backend.util.file_content_parser import (
|
||||
BINARY_FORMATS,
|
||||
MIME_TO_FORMAT,
|
||||
PARSE_EXCEPTIONS,
|
||||
infer_format_from_uri,
|
||||
parse_file_content,
|
||||
)
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
|
||||
class FileRefExpansionError(Exception):
|
||||
@@ -82,8 +74,6 @@ _FILE_REF_RE = re.compile(
|
||||
_MAX_EXPAND_CHARS = 200_000
|
||||
# Maximum total characters across all @@agptfile: expansions in one string.
|
||||
_MAX_TOTAL_EXPAND_CHARS = 1_000_000
|
||||
# Maximum raw byte size for bare ref structured parsing (10 MB).
|
||||
_MAX_BARE_REF_BYTES = 10_000_000
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -93,11 +83,6 @@ class FileRef:
|
||||
end_line: int | None # 1-indexed, inclusive
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API (top-down: main functions first, helpers below)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def parse_file_ref(text: str) -> FileRef | None:
|
||||
"""Return a :class:`FileRef` if *text* is a bare file reference token.
|
||||
|
||||
@@ -119,6 +104,17 @@ def parse_file_ref(text: str) -> FileRef | None:
|
||||
return FileRef(uri=m.group(1), start_line=start, end_line=end)
|
||||
|
||||
|
||||
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
|
||||
"""Slice *text* to the requested 1-indexed line range (inclusive)."""
|
||||
if start is None and end is None:
|
||||
return text
|
||||
lines = text.splitlines(keepends=True)
|
||||
s = (start - 1) if start is not None else 0
|
||||
e = end if end is not None else len(lines)
|
||||
selected = list(itertools.islice(lines, s, e))
|
||||
return "".join(selected)
|
||||
|
||||
|
||||
async def read_file_bytes(
|
||||
uri: str,
|
||||
user_id: str | None,
|
||||
@@ -134,47 +130,27 @@ async def read_file_bytes(
|
||||
if plain.startswith("workspace://"):
|
||||
if not user_id:
|
||||
raise ValueError("workspace:// file references require authentication")
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
ws = parse_workspace_uri(plain)
|
||||
try:
|
||||
data = await (
|
||||
return await (
|
||||
manager.read_file(ws.file_ref)
|
||||
if ws.is_path
|
||||
else manager.read_file_by_id(ws.file_ref)
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"File not found: {plain}")
|
||||
except (PermissionError, OSError) as exc:
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
except (AttributeError, TypeError, RuntimeError) as exc:
|
||||
# AttributeError/TypeError: workspace manager returned an
|
||||
# unexpected type or interface; RuntimeError: async runtime issues.
|
||||
logger.warning("Unexpected error reading %s: %s", plain, exc)
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
# NOTE: Workspace API does not support pre-read size checks;
|
||||
# the full file is loaded before the size guard below.
|
||||
if len(data) > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
return data
|
||||
|
||||
if is_allowed_local_path(plain, get_sdk_cwd()):
|
||||
resolved = os.path.realpath(os.path.expanduser(plain))
|
||||
try:
|
||||
# Read with a one-byte overshoot to detect files that exceed the limit
|
||||
# without a separate os.path.getsize call (avoids TOCTOU race).
|
||||
with open(resolved, "rb") as fh:
|
||||
data = fh.read(_MAX_BARE_REF_BYTES + 1)
|
||||
if len(data) > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large (>{_MAX_BARE_REF_BYTES} bytes, "
|
||||
f"limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
return data
|
||||
return fh.read()
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"File not found: {plain}")
|
||||
except OSError as exc:
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Failed to read {plain}: {exc}") from exc
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
@@ -186,33 +162,9 @@ async def read_file_bytes(
|
||||
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||
) from exc
|
||||
try:
|
||||
data = bytes(await sandbox.files.read(remote, format="bytes"))
|
||||
except (FileNotFoundError, OSError, UnicodeDecodeError) as exc:
|
||||
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
|
||||
return bytes(await sandbox.files.read(remote, format="bytes"))
|
||||
except Exception as exc:
|
||||
# E2B SDK raises SandboxException subclasses (NotFoundException,
|
||||
# TimeoutException, NotEnoughSpaceException, etc.) which don't
|
||||
# inherit from standard exceptions. Import lazily to avoid a
|
||||
# hard dependency on e2b at module level.
|
||||
try:
|
||||
from e2b.exceptions import SandboxException # noqa: PLC0415
|
||||
|
||||
if isinstance(exc, SandboxException):
|
||||
raise ValueError(
|
||||
f"Failed to read from sandbox: {plain}: {exc}"
|
||||
) from exc
|
||||
except ImportError:
|
||||
pass
|
||||
# Re-raise unexpected exceptions (TypeError, AttributeError, etc.)
|
||||
# so they surface as real bugs rather than being silently masked.
|
||||
raise
|
||||
# NOTE: E2B sandbox API does not support pre-read size checks;
|
||||
# the full file is loaded before the size guard below.
|
||||
if len(data) > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
return data
|
||||
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
|
||||
|
||||
raise ValueError(
|
||||
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
|
||||
@@ -226,13 +178,15 @@ async def resolve_file_ref(
|
||||
) -> str:
|
||||
"""Resolve a :class:`FileRef` to its text content."""
|
||||
raw = await read_file_bytes(ref.uri, user_id, session)
|
||||
return _apply_line_range(_to_str(raw), ref.start_line, ref.end_line)
|
||||
return _apply_line_range(
|
||||
raw.decode("utf-8", errors="replace"), ref.start_line, ref.end_line
|
||||
)
|
||||
|
||||
|
||||
async def expand_file_refs_in_string(
|
||||
text: str,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
session: "ChatSession",
|
||||
*,
|
||||
raise_on_error: bool = False,
|
||||
) -> str:
|
||||
@@ -278,9 +232,6 @@ async def expand_file_refs_in_string(
|
||||
if len(content) > _MAX_EXPAND_CHARS:
|
||||
content = content[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
|
||||
remaining = _MAX_TOTAL_EXPAND_CHARS - total_chars
|
||||
# remaining == 0 means the budget was exactly exhausted by the
|
||||
# previous ref. The elif below (len > remaining) won't catch
|
||||
# this since 0 > 0 is false, so we need the <= 0 check.
|
||||
if remaining <= 0:
|
||||
content = "[file-ref budget exhausted: total expansion limit reached]"
|
||||
elif len(content) > remaining:
|
||||
@@ -301,31 +252,13 @@ async def expand_file_refs_in_string(
|
||||
async def expand_file_refs_in_args(
|
||||
args: dict[str, Any],
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
*,
|
||||
input_schema: dict[str, Any] | None = None,
|
||||
session: "ChatSession",
|
||||
) -> dict[str, Any]:
|
||||
"""Recursively expand ``@@agptfile:...`` references in tool call arguments.
|
||||
|
||||
String values are expanded in-place. Nested dicts and lists are
|
||||
traversed. Non-string scalars are returned unchanged.
|
||||
|
||||
**Bare references** (the entire argument value is a single
|
||||
``@@agptfile:...`` token with no surrounding text) are resolved and then
|
||||
parsed according to the file's extension or MIME type. See
|
||||
:mod:`backend.util.file_content_parser` for the full list of supported
|
||||
formats (JSON, JSONL, CSV, TSV, YAML, TOML, Parquet, Excel).
|
||||
|
||||
When *input_schema* is provided and the target property has
|
||||
``"type": "string"``, structured parsing is skipped — the raw file content
|
||||
is returned as a plain string so blocks receive the original text.
|
||||
|
||||
If the format is unrecognised or parsing fails, the content is returned as
|
||||
a plain string (the fallback).
|
||||
|
||||
**Embedded references** (``@@agptfile:`` mixed with other text) always
|
||||
produce a plain string — structured parsing only applies to bare refs.
|
||||
|
||||
Raises :class:`FileRefExpansionError` if any reference fails to resolve,
|
||||
so the tool is *not* executed with an error string as its input. The
|
||||
caller (the MCP tool wrapper) should convert this into an MCP error
|
||||
@@ -334,382 +267,15 @@ async def expand_file_refs_in_args(
|
||||
if not args:
|
||||
return args
|
||||
|
||||
properties = (input_schema or {}).get("properties", {})
|
||||
|
||||
async def _expand(
|
||||
value: Any,
|
||||
*,
|
||||
prop_schema: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
"""Recursively expand a single argument value.
|
||||
|
||||
Strings are checked for ``@@agptfile:`` references and expanded
|
||||
(bare refs get structured parsing; embedded refs get inline
|
||||
substitution). Dicts and lists are traversed recursively,
|
||||
threading the corresponding sub-schema from *prop_schema* so
|
||||
that nested fields also receive correct type-aware expansion.
|
||||
Non-string scalars pass through unchanged.
|
||||
"""
|
||||
async def _expand(value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
ref = parse_file_ref(value)
|
||||
if ref is not None:
|
||||
# MediaFileType fields: return the raw URI immediately —
|
||||
# no file reading, no format inference, no content parsing.
|
||||
if _is_media_file_field(prop_schema):
|
||||
return ref.uri
|
||||
|
||||
fmt = infer_format_from_uri(ref.uri)
|
||||
# Workspace URIs by ID (workspace://abc123) have no extension.
|
||||
# When the MIME fragment is also missing, fall back to the
|
||||
# workspace file manager's metadata for format detection.
|
||||
if fmt is None and ref.uri.startswith("workspace://"):
|
||||
fmt = await _infer_format_from_workspace(ref.uri, user_id, session)
|
||||
return await _expand_bare_ref(ref, fmt, user_id, session, prop_schema)
|
||||
|
||||
# Not a bare ref — do normal inline expansion.
|
||||
return await expand_file_refs_in_string(
|
||||
value, user_id, session, raise_on_error=True
|
||||
)
|
||||
if isinstance(value, dict):
|
||||
# When the schema says this is an object but doesn't define
|
||||
# inner properties, skip expansion — the caller (e.g.
|
||||
# RunBlockTool) will expand with the actual nested schema.
|
||||
if (
|
||||
prop_schema is not None
|
||||
and prop_schema.get("type") == "object"
|
||||
and "properties" not in prop_schema
|
||||
):
|
||||
return value
|
||||
nested_props = (prop_schema or {}).get("properties", {})
|
||||
return {
|
||||
k: await _expand(v, prop_schema=nested_props.get(k))
|
||||
for k, v in value.items()
|
||||
}
|
||||
return {k: await _expand(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
items_schema = (prop_schema or {}).get("items")
|
||||
return [await _expand(item, prop_schema=items_schema) for item in value]
|
||||
return [await _expand(item) for item in value]
|
||||
return value
|
||||
|
||||
return {k: await _expand(v, prop_schema=properties.get(k)) for k, v in args.items()}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers (used by the public functions above)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
|
||||
"""Slice *text* to the requested 1-indexed line range (inclusive).
|
||||
|
||||
When the requested range extends beyond the file, a note is appended
|
||||
so the LLM knows it received the entire remaining content.
|
||||
"""
|
||||
if start is None and end is None:
|
||||
return text
|
||||
lines = text.splitlines(keepends=True)
|
||||
total = len(lines)
|
||||
s = (start - 1) if start is not None else 0
|
||||
e = end if end is not None else total
|
||||
selected = list(itertools.islice(lines, s, e))
|
||||
result = "".join(selected)
|
||||
if end is not None and end > total:
|
||||
result += f"\n[Note: file has only {total} lines]\n"
|
||||
return result
|
||||
|
||||
|
||||
def _to_str(content: str | bytes) -> str:
|
||||
"""Decode *content* to a string if it is bytes, otherwise return as-is."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return content.decode("utf-8", errors="replace")
|
||||
|
||||
|
||||
def _check_content_size(content: str | bytes) -> None:
|
||||
"""Raise :class:`ValueError` if *content* exceeds the byte limit.
|
||||
|
||||
Raises ``ValueError`` (not ``FileRefExpansionError``) so that the caller
|
||||
(``_expand_bare_ref``) can unify all resolution errors into a single
|
||||
``except ValueError`` → ``FileRefExpansionError`` handler, keeping the
|
||||
error-flow consistent with ``read_file_bytes`` and ``resolve_file_ref``.
|
||||
|
||||
For ``bytes``, the length is the byte count directly. For ``str``,
|
||||
we encode to UTF-8 first because multi-byte characters (e.g. emoji)
|
||||
mean the byte size can be up to 4x the character count.
|
||||
"""
|
||||
if isinstance(content, bytes):
|
||||
size = len(content)
|
||||
else:
|
||||
char_len = len(content)
|
||||
# Fast lower bound: UTF-8 byte count >= char count.
|
||||
# If char count already exceeds the limit, reject immediately
|
||||
# without allocating an encoded copy.
|
||||
if char_len > _MAX_BARE_REF_BYTES:
|
||||
size = char_len # real byte size is even larger
|
||||
# Fast upper bound: each char is at most 4 UTF-8 bytes.
|
||||
# If worst-case is still under the limit, skip encoding entirely.
|
||||
elif char_len * 4 <= _MAX_BARE_REF_BYTES:
|
||||
return
|
||||
else:
|
||||
# Edge case: char count is under limit but multibyte chars
|
||||
# might push byte count over. Encode to get exact size.
|
||||
size = len(content.encode("utf-8"))
|
||||
if size > _MAX_BARE_REF_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large for structured parsing "
|
||||
f"({size} bytes, limit {_MAX_BARE_REF_BYTES})"
|
||||
)
|
||||
|
||||
|
||||
async def _infer_format_from_workspace(
|
||||
uri: str,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
) -> str | None:
|
||||
"""Look up workspace file metadata to infer the format.
|
||||
|
||||
Workspace URIs by ID (``workspace://abc123``) have no file extension.
|
||||
When the MIME fragment is also absent, we query the workspace file
|
||||
manager for the file's stored MIME type and original filename.
|
||||
"""
|
||||
if not user_id:
|
||||
return None
|
||||
try:
|
||||
ws = parse_workspace_uri(uri)
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
info = await (
|
||||
manager.get_file_info(ws.file_ref)
|
||||
if not ws.is_path
|
||||
else manager.get_file_info_by_path(ws.file_ref)
|
||||
)
|
||||
if info is None:
|
||||
return None
|
||||
# Try MIME type first, then filename extension.
|
||||
mime = (info.mime_type or "").split(";", 1)[0].strip().lower()
|
||||
return MIME_TO_FORMAT.get(mime) or infer_format_from_uri(info.name)
|
||||
except (
|
||||
ValueError,
|
||||
FileNotFoundError,
|
||||
OSError,
|
||||
PermissionError,
|
||||
AttributeError,
|
||||
TypeError,
|
||||
):
|
||||
# Expected failures: bad URI, missing file, permission denied, or
|
||||
# workspace manager returning unexpected types. Propagate anything
|
||||
# else (e.g. programming errors) so they don't get silently swallowed.
|
||||
logger.debug("workspace metadata lookup failed for %s", uri, exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def _is_media_file_field(prop_schema: dict[str, Any] | None) -> bool:
|
||||
"""Return True if *prop_schema* describes a MediaFileType field (format: file)."""
|
||||
if prop_schema is None:
|
||||
return False
|
||||
return (
|
||||
prop_schema.get("type") == "string"
|
||||
and prop_schema.get("format") == MediaFileType.string_format
|
||||
)
|
||||
|
||||
|
||||
async def _expand_bare_ref(
|
||||
ref: FileRef,
|
||||
fmt: str | None,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
prop_schema: dict[str, Any] | None,
|
||||
) -> Any:
|
||||
"""Resolve and parse a bare ``@@agptfile:`` reference.
|
||||
|
||||
This is the structured-parsing path: the file is read, optionally parsed
|
||||
according to *fmt*, and adapted to the target *prop_schema*.
|
||||
|
||||
Raises :class:`FileRefExpansionError` on resolution or parsing failure.
|
||||
|
||||
Note: MediaFileType fields (format: "file") are handled earlier in
|
||||
``_expand`` to avoid unnecessary format inference and file I/O.
|
||||
"""
|
||||
try:
|
||||
if fmt is not None and fmt in BINARY_FORMATS:
|
||||
# Binary formats need raw bytes, not UTF-8 text.
|
||||
# Line ranges are meaningless for binary formats (parquet/xlsx)
|
||||
# — ignore them and parse full bytes. Warn so the caller/model
|
||||
# knows the range was silently dropped.
|
||||
if ref.start_line is not None or ref.end_line is not None:
|
||||
logger.warning(
|
||||
"Line range [%s-%s] ignored for binary format %s (%s); "
|
||||
"binary formats are always parsed in full.",
|
||||
ref.start_line,
|
||||
ref.end_line,
|
||||
fmt,
|
||||
ref.uri,
|
||||
)
|
||||
content: str | bytes = await read_file_bytes(ref.uri, user_id, session)
|
||||
else:
|
||||
content = await resolve_file_ref(ref, user_id, session)
|
||||
except ValueError as exc:
|
||||
raise FileRefExpansionError(str(exc)) from exc
|
||||
|
||||
# For known formats this rejects files >10 MB before parsing.
|
||||
# For unknown formats _MAX_EXPAND_CHARS (200K chars) below is stricter,
|
||||
# but this check still guards the parsing path which has no char limit.
|
||||
# _check_content_size raises ValueError, which we unify here just like
|
||||
# resolution errors above.
|
||||
try:
|
||||
_check_content_size(content)
|
||||
except ValueError as exc:
|
||||
raise FileRefExpansionError(str(exc)) from exc
|
||||
|
||||
# When the schema declares this parameter as "string",
|
||||
# return raw file content — don't parse into a structured
|
||||
# type that would need json.dumps() serialisation.
|
||||
expect_string = (prop_schema or {}).get("type") == "string"
|
||||
if expect_string:
|
||||
if isinstance(content, bytes):
|
||||
raise FileRefExpansionError(
|
||||
f"Cannot use {fmt} file as text input: "
|
||||
f"binary formats (parquet, xlsx) must be passed "
|
||||
f"to a block that accepts structured data (list/object), "
|
||||
f"not a string-typed parameter."
|
||||
)
|
||||
return content
|
||||
|
||||
if fmt is not None:
|
||||
# Use strict mode for binary formats so we surface the
|
||||
# actual error (e.g. missing pyarrow/openpyxl, corrupt
|
||||
# file) instead of silently returning garbled bytes.
|
||||
strict = fmt in BINARY_FORMATS
|
||||
try:
|
||||
parsed = parse_file_content(content, fmt, strict=strict)
|
||||
except PARSE_EXCEPTIONS as exc:
|
||||
raise FileRefExpansionError(f"Failed to parse {fmt} file: {exc}") from exc
|
||||
# Normalize bytes fallback to str so tools never
|
||||
# receive raw bytes when parsing fails.
|
||||
if isinstance(parsed, bytes):
|
||||
parsed = _to_str(parsed)
|
||||
return _adapt_to_schema(parsed, prop_schema)
|
||||
|
||||
# Unknown format — return as plain string, but apply
|
||||
# the same per-ref character limit used by inline refs
|
||||
# to prevent injecting unexpectedly large content.
|
||||
text = _to_str(content)
|
||||
if len(text) > _MAX_EXPAND_CHARS:
|
||||
text = text[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
|
||||
return text
|
||||
|
||||
|
||||
def _adapt_to_schema(parsed: Any, prop_schema: dict[str, Any] | None) -> Any:
|
||||
"""Adapt a parsed file value to better fit the target schema type.
|
||||
|
||||
When the parser returns a natural type (e.g. dict from YAML, list from CSV)
|
||||
that doesn't match the block's expected type, this function converts it to
|
||||
a more useful representation instead of relying on pydantic's generic
|
||||
coercion (which can produce awkward results like flattened dicts → lists).
|
||||
|
||||
Returns *parsed* unchanged when no adaptation is needed.
|
||||
"""
|
||||
if prop_schema is None:
|
||||
return parsed
|
||||
|
||||
target_type = prop_schema.get("type")
|
||||
|
||||
# Dict → array: delegate to helper.
|
||||
if isinstance(parsed, dict) and target_type == "array":
|
||||
return _adapt_dict_to_array(parsed, prop_schema)
|
||||
|
||||
# List → object: delegate to helper (raises for non-tabular lists).
|
||||
if isinstance(parsed, list) and target_type == "object":
|
||||
return _adapt_list_to_object(parsed)
|
||||
|
||||
# Tabular list → Any (no type): convert to list of dicts.
|
||||
# Blocks like FindInDictionaryBlock have `input: Any` which produces
|
||||
# a schema with no "type" key. Tabular [[header],[rows]] is unusable
|
||||
# for key lookup, but [{col: val}, ...] works with FindInDict's
|
||||
# list-of-dicts branch (line 195-199 in data_manipulation.py).
|
||||
if isinstance(parsed, list) and target_type is None and _is_tabular(parsed):
|
||||
return _tabular_to_list_of_dicts(parsed)
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
def _adapt_dict_to_array(parsed: dict, prop_schema: dict[str, Any]) -> Any:
|
||||
"""Adapt a parsed dict to an array-typed field.
|
||||
|
||||
Extracts list-valued entries when the target item type is ``array``,
|
||||
passes through unchanged when item type is ``string`` (lets pydantic error),
|
||||
or wraps in ``[parsed]`` as a fallback.
|
||||
"""
|
||||
items_type = (prop_schema.get("items") or {}).get("type")
|
||||
if items_type == "array":
|
||||
# Target is List[List[Any]] — extract list-typed values from the
|
||||
# dict as inner lists. E.g. YAML {"fruits": [{...},...]}} with
|
||||
# ConcatenateLists (List[List[Any]]) → [[{...},...]].
|
||||
list_values = [v for v in parsed.values() if isinstance(v, list)]
|
||||
if list_values:
|
||||
return list_values
|
||||
if items_type == "string":
|
||||
# Target is List[str] — wrapping a dict would give [dict]
|
||||
# which can't coerce to strings. Return unchanged and let
|
||||
# pydantic surface a clear validation error.
|
||||
return parsed
|
||||
# Fallback: wrap in a single-element list so the block gets [dict]
|
||||
# instead of pydantic flattening keys/values into a flat list.
|
||||
return [parsed]
|
||||
|
||||
|
||||
def _adapt_list_to_object(parsed: list) -> Any:
|
||||
"""Adapt a parsed list to an object-typed field.
|
||||
|
||||
Converts tabular lists to column-dicts; raises for non-tabular lists.
|
||||
"""
|
||||
if _is_tabular(parsed):
|
||||
return _tabular_to_column_dict(parsed)
|
||||
# Non-tabular list (e.g. a plain Python list from a YAML file) cannot
|
||||
# be meaningfully coerced to an object. Raise explicitly so callers
|
||||
# get a clear error rather than pydantic silently wrapping the list.
|
||||
raise FileRefExpansionError(
|
||||
"Cannot adapt a non-tabular list to an object-typed field. "
|
||||
"Expected a tabular structure ([[header], [row1], ...]) or a dict."
|
||||
)
|
||||
|
||||
|
||||
def _is_tabular(parsed: Any) -> bool:
|
||||
"""Check if parsed data is in tabular format: [[header], [row1], ...].
|
||||
|
||||
Uses isinstance checks because this is a structural type guard on
|
||||
opaque parser output (Any), not duck typing. A Protocol wouldn't
|
||||
help here — we need to verify exact list-of-lists shape.
|
||||
"""
|
||||
if not isinstance(parsed, list) or len(parsed) < 2:
|
||||
return False
|
||||
header = parsed[0]
|
||||
if not isinstance(header, list) or not header:
|
||||
return False
|
||||
if not all(isinstance(h, str) for h in header):
|
||||
return False
|
||||
return all(isinstance(row, list) for row in parsed[1:])
|
||||
|
||||
|
||||
def _tabular_to_list_of_dicts(parsed: list) -> list[dict[str, Any]]:
|
||||
"""Convert [[header], [row1], ...] → [{header[0]: row[0], ...}, ...].
|
||||
|
||||
Ragged rows (fewer columns than the header) get None for missing values.
|
||||
Extra values beyond the header length are silently dropped.
|
||||
"""
|
||||
header = parsed[0]
|
||||
return [
|
||||
dict(itertools.zip_longest(header, row[: len(header)], fillvalue=None))
|
||||
for row in parsed[1:]
|
||||
]
|
||||
|
||||
|
||||
def _tabular_to_column_dict(parsed: list) -> dict[str, list]:
|
||||
"""Convert [[header], [row1], ...] → {"col1": [val1, ...], ...}.
|
||||
|
||||
Ragged rows (fewer columns than the header) get None for missing values,
|
||||
ensuring all columns have equal length.
|
||||
"""
|
||||
header = parsed[0]
|
||||
return {
|
||||
col: [row[i] if i < len(row) else None for row in parsed[1:]]
|
||||
for i, col in enumerate(header)
|
||||
}
|
||||
return {k: await _expand(v) for k, v in args.items()}
|
||||
|
||||
@@ -175,199 +175,6 @@ async def test_expand_args_replaces_file_ref_in_nested_dict():
|
||||
assert result["count"] == 42
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# expand_file_refs_in_args — bare ref structured parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_json_returns_parsed_dict():
|
||||
"""Bare ref to a .json file returns parsed dict, not raw string."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
json_file = os.path.join(sdk_cwd, "data.json")
|
||||
with open(json_file, "w") as f:
|
||||
f.write('{"key": "value", "count": 42}')
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{json_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["data"] == {"key": "value", "count": 42}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_csv_returns_parsed_table():
|
||||
"""Bare ref to a .csv file returns list[list[str]] table."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
csv_file = os.path.join(sdk_cwd, "data.csv")
|
||||
with open(csv_file, "w") as f:
|
||||
f.write("Name,Score\nAlice,90\nBob,85")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"input": f"@@agptfile:{csv_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["input"] == [
|
||||
["Name", "Score"],
|
||||
["Alice", "90"],
|
||||
["Bob", "85"],
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_unknown_extension_returns_string():
|
||||
"""Bare ref to a file with unknown extension returns plain string."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
txt_file = os.path.join(sdk_cwd, "readme.txt")
|
||||
with open(txt_file, "w") as f:
|
||||
f.write("plain text content")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{txt_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["data"] == "plain text content"
|
||||
assert isinstance(result["data"], str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_invalid_json_falls_back_to_string():
|
||||
"""Bare ref to a .json file with invalid JSON falls back to string."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
json_file = os.path.join(sdk_cwd, "bad.json")
|
||||
with open(json_file, "w") as f:
|
||||
f.write("not valid json {{{")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{json_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["data"] == "not valid json {{{"
|
||||
assert isinstance(result["data"], str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedded_ref_always_returns_string_even_for_json():
|
||||
"""Embedded ref (text around it) returns plain string, not parsed JSON."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
json_file = os.path.join(sdk_cwd, "data.json")
|
||||
with open(json_file, "w") as f:
|
||||
f.write('{"key": "value"}')
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"prefix @@agptfile:{json_file} suffix"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert isinstance(result["data"], str)
|
||||
assert result["data"].startswith("prefix ")
|
||||
assert result["data"].endswith(" suffix")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_yaml_returns_parsed_dict():
|
||||
"""Bare ref to a .yaml file returns parsed dict."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
yaml_file = os.path.join(sdk_cwd, "config.yaml")
|
||||
with open(yaml_file, "w") as f:
|
||||
f.write("name: test\ncount: 42\n")
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"config": f"@@agptfile:{yaml_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["config"] == {"name": "test", "count": 42}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_binary_with_line_range_ignores_range():
|
||||
"""Bare ref to a binary file (.parquet) with line range parses the full file.
|
||||
|
||||
Binary formats (parquet, xlsx) ignore line ranges — the full content is
|
||||
parsed and the range is silently dropped with a log warning.
|
||||
"""
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pytest.skip("pandas not installed")
|
||||
try:
|
||||
import pyarrow # noqa: F401 # pyright: ignore[reportMissingImports]
|
||||
except ImportError:
|
||||
pytest.skip("pyarrow not installed")
|
||||
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
parquet_file = os.path.join(sdk_cwd, "data.parquet")
|
||||
import io as _io
|
||||
|
||||
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
|
||||
buf = _io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
with open(parquet_file, "wb") as f:
|
||||
f.write(buf.getvalue())
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
# Line range [1-2] should be silently ignored for binary formats.
|
||||
result = await expand_file_refs_in_args(
|
||||
{"data": f"@@agptfile:{parquet_file}[1-2]"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
# Full file is returned despite the line range.
|
||||
assert result["data"] == [["A", "B"], [1, 4], [2, 5], [3, 6]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bare_ref_toml_returns_parsed_dict():
|
||||
"""Bare ref to a .toml file returns parsed dict."""
|
||||
with tempfile.TemporaryDirectory() as sdk_cwd:
|
||||
toml_file = os.path.join(sdk_cwd, "config.toml")
|
||||
with open(toml_file, "w") as f:
|
||||
f.write('name = "test"\ncount = 42\n')
|
||||
|
||||
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
|
||||
mock_cwd_var.get.return_value = sdk_cwd
|
||||
|
||||
result = await expand_file_refs_in_args(
|
||||
{"config": f"@@agptfile:{toml_file}"},
|
||||
user_id="u1",
|
||||
session=_make_session(),
|
||||
)
|
||||
|
||||
assert result["config"] == {"name": "test", "count": 42}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_file_handler — extended to accept workspace:// and local paths
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -412,7 +219,7 @@ async def test_read_file_handler_workspace_uri():
|
||||
"backend.copilot.sdk.tool_adapter.get_execution_context",
|
||||
return_value=("user-1", mock_session),
|
||||
), patch(
|
||||
"backend.copilot.sdk.file_ref.get_workspace_manager",
|
||||
"backend.copilot.sdk.file_ref.get_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
):
|
||||
result = await _read_file_handler(
|
||||
@@ -469,7 +276,7 @@ async def test_read_file_bytes_workspace_virtual_path():
|
||||
mock_manager.read_file.return_value = b"virtual path content"
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.file_ref.get_workspace_manager",
|
||||
"backend.copilot.sdk.file_ref.get_manager",
|
||||
new=AsyncMock(return_value=mock_manager),
|
||||
):
|
||||
result = await read_file_bytes("workspace:///reports/q1.md", "user-1", session)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -29,7 +29,6 @@ from langfuse import propagate_attributes
|
||||
from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.context import get_workspace_manager
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.executor.cluster_lock import AsyncClusterLock
|
||||
from backend.util.exceptions import NotFoundError
|
||||
@@ -63,6 +62,7 @@ from ..service import (
|
||||
)
|
||||
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
|
||||
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
||||
from ..tools.workspace_files import get_manager
|
||||
from ..tracking import track_user_message
|
||||
from .compaction import CompactionTracker, filter_compaction_messages
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
@@ -565,7 +565,7 @@ async def _prepare_file_attachments(
|
||||
return empty
|
||||
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to create workspace manager for file attachments",
|
||||
@@ -769,7 +769,7 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
return None
|
||||
try:
|
||||
sandbox = await get_or_create_sandbox(
|
||||
return await get_or_create_sandbox(
|
||||
session_id,
|
||||
api_key=e2b_api_key,
|
||||
template=config.e2b_sandbox_template,
|
||||
@@ -783,9 +783,7 @@ async def stream_chat_completion_sdk(
|
||||
e2b_err,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
return sandbox
|
||||
return None
|
||||
|
||||
async def _fetch_transcript():
|
||||
"""Download transcript for --resume if applicable."""
|
||||
|
||||
@@ -20,7 +20,7 @@ class _FakeFileInfo:
|
||||
size_bytes: int
|
||||
|
||||
|
||||
_PATCH_TARGET = "backend.copilot.sdk.service.get_workspace_manager"
|
||||
_PATCH_TARGET = "backend.copilot.sdk.service.get_manager"
|
||||
|
||||
|
||||
class TestPrepareFileAttachments:
|
||||
|
||||
@@ -347,7 +347,7 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
:func:`get_sdk_disallowed_tools`.
|
||||
"""
|
||||
|
||||
def _truncating(fn, tool_name: str, input_schema: dict[str, Any] | None = None):
|
||||
def _truncating(fn, tool_name: str):
|
||||
"""Wrap a tool handler so its response is truncated to stay under the
|
||||
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
|
||||
response adapter before the SDK can apply its own head-truncation.
|
||||
@@ -361,9 +361,7 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
user_id, session = get_execution_context()
|
||||
if session is not None:
|
||||
try:
|
||||
args = await expand_file_refs_in_args(
|
||||
args, user_id, session, input_schema=input_schema
|
||||
)
|
||||
args = await expand_file_refs_in_args(args, user_id, session)
|
||||
except FileRefExpansionError as exc:
|
||||
return _mcp_error(
|
||||
f"@@agptfile: reference could not be resolved: {exc}. "
|
||||
@@ -391,12 +389,11 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
handler = create_tool_handler(base_tool)
|
||||
schema = _build_input_schema(base_tool)
|
||||
decorated = tool(
|
||||
tool_name,
|
||||
base_tool.description,
|
||||
schema,
|
||||
)(_truncating(handler, tool_name, input_schema=schema))
|
||||
_build_input_schema(base_tool),
|
||||
)(_truncating(handler, tool_name))
|
||||
sdk_tools.append(decorated)
|
||||
|
||||
# E2B file tools replace SDK built-in Read/Write/Edit/Glob/Grep.
|
||||
|
||||
@@ -12,7 +12,6 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .bash_exec import BashExecTool
|
||||
from .connect_integration import ConnectIntegrationTool
|
||||
from .continue_run_block import ContinueRunBlockTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .customize_agent import CustomizeAgentTool
|
||||
@@ -85,7 +84,6 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"browser_screenshot": BrowserScreenshotTool(),
|
||||
# Sandboxed code execution (bubblewrap)
|
||||
"bash_exec": BashExecTool(),
|
||||
"connect_integration": ConnectIntegrationTool(),
|
||||
# Persistent workspace tools (cloud storage, survives across sessions)
|
||||
# Feature request tools
|
||||
"search_feature_requests": SearchFeatureRequestsTool(),
|
||||
|
||||
@@ -32,7 +32,6 @@ import shutil
|
||||
import tempfile
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.context import get_workspace_manager
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.util.request import validate_url_host
|
||||
|
||||
@@ -44,6 +43,7 @@ from .models import (
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .workspace_files import get_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -194,7 +194,7 @@ async def _save_browser_state(
|
||||
),
|
||||
}
|
||||
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
await manager.write_file(
|
||||
content=json.dumps(state).encode("utf-8"),
|
||||
filename=_STATE_FILENAME,
|
||||
@@ -218,7 +218,7 @@ async def _restore_browser_state(
|
||||
Returns True on success (or no state to restore), False on failure.
|
||||
"""
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session.session_id)
|
||||
manager = await get_manager(user_id, session.session_id)
|
||||
|
||||
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
|
||||
if file_info is None:
|
||||
@@ -360,7 +360,7 @@ async def close_browser_session(session_name: str, user_id: str | None = None) -
|
||||
# Delete persisted browser state (cookies, localStorage) from workspace.
|
||||
if user_id:
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_name)
|
||||
manager = await get_manager(user_id, session_name)
|
||||
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
|
||||
if file_info is not None:
|
||||
await manager.delete_file(file_info.id)
|
||||
|
||||
@@ -897,7 +897,7 @@ class TestHasLocalSession:
|
||||
# _save_browser_state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_GET_MANAGER = "backend.copilot.tools.agent_browser.get_workspace_manager"
|
||||
_GET_MANAGER = "backend.copilot.tools.agent_browser.get_manager"
|
||||
|
||||
|
||||
def _make_mock_manager():
|
||||
|
||||
@@ -3,8 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.api.features.library.model import LibraryAgent
|
||||
@@ -19,16 +18,12 @@ from .models import (
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .utils import is_creator_slug, is_uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SearchSource = Literal["marketplace", "library"]
|
||||
|
||||
_UUID_PATTERN = re.compile(
|
||||
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Keywords that should be treated as "list all" rather than a literal search
|
||||
_LIST_ALL_KEYWORDS = frozenset({"all", "*", "everything", "any", ""})
|
||||
|
||||
@@ -39,149 +34,158 @@ async def search_agents(
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> ToolResponseBase:
|
||||
"""
|
||||
Search for agents in marketplace or user library.
|
||||
"""Search for agents in marketplace or user library."""
|
||||
if source == "marketplace":
|
||||
return await _search_marketplace(query, session_id)
|
||||
else:
|
||||
return await _search_library(query, session_id, user_id)
|
||||
|
||||
For library searches, keywords like "all", "*", "everything", or an empty
|
||||
query will list all agents without filtering.
|
||||
|
||||
Args:
|
||||
query: Search query string. Special keywords list all library agents.
|
||||
source: "marketplace" or "library"
|
||||
session_id: Chat session ID
|
||||
user_id: User ID (required for library search)
|
||||
|
||||
Returns:
|
||||
AgentsFoundResponse, NoResultsResponse, or ErrorResponse
|
||||
"""
|
||||
# Normalize list-all keywords to empty string for library searches
|
||||
if source == "library" and query.lower().strip() in _LIST_ALL_KEYWORDS:
|
||||
query = ""
|
||||
|
||||
if source == "marketplace" and not query:
|
||||
async def _search_marketplace(query: str, session_id: str | None) -> ToolResponseBase:
|
||||
query = query.strip()
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query", session_id=session_id
|
||||
)
|
||||
|
||||
if source == "library" and not user_id:
|
||||
return ErrorResponse(
|
||||
message="User authentication required to search library",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agents: list[AgentInfo] = []
|
||||
try:
|
||||
if source == "marketplace":
|
||||
# Direct lookup if query matches "creator/slug" pattern
|
||||
if is_creator_slug(query):
|
||||
logger.info(f"Query looks like creator/slug, trying direct lookup: {query}")
|
||||
creator, slug = query.split("/", 1)
|
||||
agent_info = await _get_marketplace_agent_by_slug(creator, slug)
|
||||
if agent_info:
|
||||
agents.append(agent_info)
|
||||
|
||||
if not agents:
|
||||
logger.info(f"Searching marketplace for: {query}")
|
||||
results = await store_db().get_store_agents(search_query=query, page_size=5)
|
||||
for agent in results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=f"{agent.creator}/{agent.slug}",
|
||||
name=agent.agent_name,
|
||||
description=agent.description or "",
|
||||
source="marketplace",
|
||||
in_library=False,
|
||||
creator=agent.creator,
|
||||
category="general",
|
||||
rating=agent.rating,
|
||||
runs=agent.runs,
|
||||
is_featured=False,
|
||||
)
|
||||
)
|
||||
else:
|
||||
if _is_uuid(query):
|
||||
logger.info(f"Query looks like UUID, trying direct lookup: {query}")
|
||||
agent = await _get_library_agent_by_id(user_id, query) # type: ignore[arg-type]
|
||||
if agent:
|
||||
agents.append(agent)
|
||||
logger.info(f"Found agent by direct ID lookup: {agent.name}")
|
||||
|
||||
if not agents:
|
||||
search_term = query or None
|
||||
logger.info(
|
||||
f"{'Listing all agents in' if not query else 'Searching'} "
|
||||
f"user library{'' if not query else f' for: {query}'}"
|
||||
)
|
||||
results = await library_db().list_library_agents(
|
||||
user_id=user_id, # type: ignore[arg-type]
|
||||
search_term=search_term,
|
||||
page_size=50 if not query else 10,
|
||||
)
|
||||
for agent in results.agents:
|
||||
agents.append(_library_agent_to_info(agent))
|
||||
logger.info(f"Found {len(agents)} agents in {source}")
|
||||
agents.append(_marketplace_agent_to_info(agent))
|
||||
except NotFoundError:
|
||||
pass
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Error searching {source}: {e}", exc_info=True)
|
||||
logger.error(f"Error searching marketplace: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to search {source}. Please try again.",
|
||||
message="Failed to search marketplace. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not agents:
|
||||
if source == "marketplace":
|
||||
suggestions = [
|
||||
"Try more general terms",
|
||||
"Browse categories in the marketplace",
|
||||
"Check spelling",
|
||||
]
|
||||
no_results_msg = (
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"No agents found matching '{query}'. Let the user know they can "
|
||||
"try different keywords or browse the marketplace. Also let them "
|
||||
"know you can create a custom agent for them based on their needs."
|
||||
),
|
||||
suggestions=[
|
||||
"Try more general terms",
|
||||
"Browse categories in the marketplace",
|
||||
"Check spelling",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return AgentsFoundResponse(
|
||||
message=(
|
||||
"Now you have found some options for the user to choose from. "
|
||||
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
||||
"Please ask the user if they would like to use any of these agents. "
|
||||
"Let the user know we can create a custom agent for them based on their needs."
|
||||
),
|
||||
title=f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} for '{query}'",
|
||||
agents=agents,
|
||||
count=len(agents),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
async def _search_library(
|
||||
query: str, session_id: str | None, user_id: str | None
|
||||
) -> ToolResponseBase:
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="User authentication required to search library",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
query = query.strip()
|
||||
# Normalize list-all keywords to empty string
|
||||
if query.lower() in _LIST_ALL_KEYWORDS:
|
||||
query = ""
|
||||
|
||||
agents: list[AgentInfo] = []
|
||||
try:
|
||||
if is_uuid(query):
|
||||
logger.info(f"Query looks like UUID, trying direct lookup: {query}")
|
||||
agent = await _get_library_agent_by_id(user_id, query)
|
||||
if agent:
|
||||
agents.append(agent)
|
||||
|
||||
if not agents:
|
||||
logger.info(
|
||||
f"{'Listing all agents in' if not query else 'Searching'} "
|
||||
f"user library{'' if not query else f' for: {query}'}"
|
||||
)
|
||||
elif not query:
|
||||
# User asked to list all but library is empty
|
||||
suggestions = [
|
||||
"Browse the marketplace to find and add agents",
|
||||
"Use find_agent to search the marketplace",
|
||||
]
|
||||
no_results_msg = (
|
||||
"Your library is empty. Let the user know they can browse the "
|
||||
"marketplace to find agents, or you can create a custom agent "
|
||||
"for them based on their needs."
|
||||
results = await library_db().list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=query or None,
|
||||
page_size=50 if not query else 10,
|
||||
)
|
||||
else:
|
||||
suggestions = [
|
||||
"Try different keywords",
|
||||
"Use find_agent to search the marketplace",
|
||||
"Check your library at /library",
|
||||
]
|
||||
no_results_msg = (
|
||||
for agent in results.agents:
|
||||
agents.append(_library_agent_to_info(agent))
|
||||
except NotFoundError:
|
||||
pass
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Error searching library: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to search library. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not agents:
|
||||
if not query:
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
"Your library is empty. Let the user know they can browse the "
|
||||
"marketplace to find agents, or you can create a custom agent "
|
||||
"for them based on their needs."
|
||||
),
|
||||
suggestions=[
|
||||
"Browse the marketplace to find and add agents",
|
||||
"Use find_agent to search the marketplace",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"No agents matching '{query}' found in your library. Let the "
|
||||
"user know you can create a custom agent for them based on "
|
||||
"their needs."
|
||||
)
|
||||
return NoResultsResponse(
|
||||
message=no_results_msg, session_id=session_id, suggestions=suggestions
|
||||
),
|
||||
suggestions=[
|
||||
"Try different keywords",
|
||||
"Use find_agent to search the marketplace",
|
||||
"Check your library at /library",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if source == "marketplace":
|
||||
title = (
|
||||
f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} for '{query}'"
|
||||
)
|
||||
elif not query:
|
||||
if not query:
|
||||
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} in your library"
|
||||
else:
|
||||
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} in your library for '{query}'"
|
||||
|
||||
message = (
|
||||
"Now you have found some options for the user to choose from. "
|
||||
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
||||
"Please ask the user if they would like to use any of these agents. "
|
||||
"Let the user know we can create a custom agent for them based on their needs."
|
||||
if source == "marketplace"
|
||||
else "Found agents in the user's library. You can provide a link to view "
|
||||
"an agent at: /library/agents/{agent_id}. Use agent_output to get "
|
||||
"execution results, or run_agent to execute. Let the user know we can "
|
||||
"create a custom agent for them based on their needs."
|
||||
)
|
||||
|
||||
return AgentsFoundResponse(
|
||||
message=message,
|
||||
message=(
|
||||
"Found agents in the user's library. You can provide a link to view "
|
||||
"an agent at: /library/agents/{agent_id}. Use agent_output to get "
|
||||
"execution results, or run_agent to execute. Let the user know we can "
|
||||
"create a custom agent for them based on their needs."
|
||||
),
|
||||
title=title,
|
||||
agents=agents,
|
||||
count=len(agents),
|
||||
@@ -189,9 +193,20 @@ async def search_agents(
|
||||
)
|
||||
|
||||
|
||||
def _is_uuid(text: str) -> bool:
|
||||
"""Check if text is a valid UUID v4."""
|
||||
return bool(_UUID_PATTERN.match(text.strip()))
|
||||
def _marketplace_agent_to_info(agent: Any) -> AgentInfo:
|
||||
"""Convert a marketplace agent (StoreAgent or StoreAgentDetails) to an AgentInfo."""
|
||||
return AgentInfo(
|
||||
id=f"{agent.creator}/{agent.slug}",
|
||||
name=agent.agent_name,
|
||||
description=agent.description or "",
|
||||
source="marketplace",
|
||||
in_library=False,
|
||||
creator=agent.creator,
|
||||
category="general",
|
||||
rating=agent.rating,
|
||||
runs=agent.runs,
|
||||
is_featured=False,
|
||||
)
|
||||
|
||||
|
||||
def _library_agent_to_info(agent: LibraryAgent) -> AgentInfo:
|
||||
@@ -214,6 +229,23 @@ def _library_agent_to_info(agent: LibraryAgent) -> AgentInfo:
|
||||
)
|
||||
|
||||
|
||||
async def _get_marketplace_agent_by_slug(creator: str, slug: str) -> AgentInfo | None:
|
||||
"""Fetch a marketplace agent by creator/slug identifier."""
|
||||
try:
|
||||
details = await store_db().get_store_agent_details(creator, slug)
|
||||
return _marketplace_agent_to_info(details)
|
||||
except NotFoundError:
|
||||
pass
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not fetch marketplace agent {creator}/{slug}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
|
||||
"""Fetch a library agent by ID (library agent ID or graph_id).
|
||||
|
||||
@@ -226,10 +258,9 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
|
||||
try:
|
||||
agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||
return _library_agent_to_info(agent)
|
||||
except NotFoundError:
|
||||
logger.debug(f"Library agent not found by graph_id: {agent_id}")
|
||||
pass
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -241,10 +272,9 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
|
||||
try:
|
||||
agent = await lib_db.get_library_agent(agent_id, user_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||
return _library_agent_to_info(agent)
|
||||
except NotFoundError:
|
||||
logger.debug(f"Library agent not found by library_id: {agent_id}")
|
||||
pass
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
"""Tests for agent search direct lookup functionality."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .agent_search import search_agents
|
||||
from .models import AgentsFoundResponse, NoResultsResponse
|
||||
|
||||
_TEST_USER_ID = "test-user-agent-search"
|
||||
|
||||
|
||||
class TestMarketplaceSlugLookup:
|
||||
"""Tests for creator/slug direct lookup in marketplace search."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_slug_lookup_found(self):
|
||||
"""creator/slug query returns the agent directly."""
|
||||
mock_details = MagicMock()
|
||||
mock_details.creator = "testuser"
|
||||
mock_details.slug = "my-agent"
|
||||
mock_details.agent_name = "My Agent"
|
||||
mock_details.description = "A test agent"
|
||||
mock_details.rating = 4.5
|
||||
mock_details.runs = 100
|
||||
|
||||
mock_store = MagicMock()
|
||||
mock_store.get_store_agent_details = AsyncMock(return_value=mock_details)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.agent_search.store_db",
|
||||
return_value=mock_store,
|
||||
):
|
||||
response = await search_agents(
|
||||
query="testuser/my-agent",
|
||||
source="marketplace",
|
||||
session_id="test-session",
|
||||
)
|
||||
|
||||
assert isinstance(response, AgentsFoundResponse)
|
||||
assert response.count == 1
|
||||
assert response.agents[0].id == "testuser/my-agent"
|
||||
assert response.agents[0].name == "My Agent"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_slug_lookup_not_found_falls_back_to_search(self):
|
||||
"""creator/slug not found falls back to general search."""
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
mock_store = MagicMock()
|
||||
mock_store.get_store_agent_details = AsyncMock(side_effect=NotFoundError(""))
|
||||
|
||||
# Fallback search returns results
|
||||
mock_search_results = MagicMock()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.creator = "other"
|
||||
mock_agent.slug = "similar-agent"
|
||||
mock_agent.agent_name = "Similar Agent"
|
||||
mock_agent.description = "A similar agent"
|
||||
mock_agent.rating = 3.0
|
||||
mock_agent.runs = 50
|
||||
mock_search_results.agents = [mock_agent]
|
||||
|
||||
mock_store.get_store_agents = AsyncMock(return_value=mock_search_results)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.agent_search.store_db",
|
||||
return_value=mock_store,
|
||||
):
|
||||
response = await search_agents(
|
||||
query="testuser/my-agent",
|
||||
source="marketplace",
|
||||
session_id="test-session",
|
||||
)
|
||||
|
||||
assert isinstance(response, AgentsFoundResponse)
|
||||
assert response.count == 1
|
||||
assert response.agents[0].id == "other/similar-agent"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_slug_lookup_not_found_no_search_results(self):
|
||||
"""creator/slug not found and search returns nothing."""
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
mock_store = MagicMock()
|
||||
mock_store.get_store_agent_details = AsyncMock(side_effect=NotFoundError(""))
|
||||
mock_search_results = MagicMock()
|
||||
mock_search_results.agents = []
|
||||
mock_store.get_store_agents = AsyncMock(return_value=mock_search_results)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.agent_search.store_db",
|
||||
return_value=mock_store,
|
||||
):
|
||||
response = await search_agents(
|
||||
query="testuser/nonexistent",
|
||||
source="marketplace",
|
||||
session_id="test-session",
|
||||
)
|
||||
|
||||
assert isinstance(response, NoResultsResponse)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_non_slug_query_goes_to_search(self):
|
||||
"""Regular keyword query skips slug lookup and goes to search."""
|
||||
mock_store = MagicMock()
|
||||
mock_search_results = MagicMock()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.creator = "creator1"
|
||||
mock_agent.slug = "email-agent"
|
||||
mock_agent.agent_name = "Email Agent"
|
||||
mock_agent.description = "Sends emails"
|
||||
mock_agent.rating = 4.0
|
||||
mock_agent.runs = 200
|
||||
mock_search_results.agents = [mock_agent]
|
||||
mock_store.get_store_agents = AsyncMock(return_value=mock_search_results)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.agent_search.store_db",
|
||||
return_value=mock_store,
|
||||
):
|
||||
response = await search_agents(
|
||||
query="email",
|
||||
source="marketplace",
|
||||
session_id="test-session",
|
||||
)
|
||||
|
||||
assert isinstance(response, AgentsFoundResponse)
|
||||
# get_store_agent_details should NOT have been called
|
||||
mock_store.get_store_agent_details.assert_not_called()
|
||||
|
||||
|
||||
class TestLibraryUUIDLookup:
|
||||
"""Tests for UUID direct lookup in library search."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_uuid_lookup_found_by_graph_id(self):
|
||||
"""UUID query matching a graph_id returns the agent directly."""
|
||||
agent_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.id = "lib-agent-id"
|
||||
mock_agent.name = "My Library Agent"
|
||||
mock_agent.description = "A library agent"
|
||||
mock_agent.creator_name = "testuser"
|
||||
mock_agent.status.value = "HEALTHY"
|
||||
mock_agent.can_access_graph = True
|
||||
mock_agent.has_external_trigger = False
|
||||
mock_agent.new_output = False
|
||||
mock_agent.graph_id = agent_id
|
||||
mock_agent.graph_version = 1
|
||||
mock_agent.input_schema = {}
|
||||
mock_agent.output_schema = {}
|
||||
|
||||
mock_lib_db = MagicMock()
|
||||
mock_lib_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.agent_search.library_db",
|
||||
return_value=mock_lib_db,
|
||||
):
|
||||
response = await search_agents(
|
||||
query=agent_id,
|
||||
source="library",
|
||||
session_id="test-session",
|
||||
user_id=_TEST_USER_ID,
|
||||
)
|
||||
|
||||
assert isinstance(response, AgentsFoundResponse)
|
||||
assert response.count == 1
|
||||
assert response.agents[0].name == "My Library Agent"
|
||||
@@ -22,7 +22,6 @@ from e2b import AsyncSandbox
|
||||
from e2b.exceptions import TimeoutException
|
||||
|
||||
from backend.copilot.context import E2B_WORKDIR, get_current_sandbox
|
||||
from backend.copilot.integration_creds import get_integration_env_vars
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
@@ -97,9 +96,7 @@ class BashExecTool(BaseTool):
|
||||
|
||||
sandbox = get_current_sandbox()
|
||||
if sandbox is not None:
|
||||
return await self._execute_on_e2b(
|
||||
sandbox, command, timeout, session_id, user_id
|
||||
)
|
||||
return await self._execute_on_e2b(sandbox, command, timeout, session_id)
|
||||
|
||||
# Bubblewrap fallback: local isolated execution.
|
||||
if not has_full_sandbox():
|
||||
@@ -136,27 +133,14 @@ class BashExecTool(BaseTool):
|
||||
command: str,
|
||||
timeout: int,
|
||||
session_id: str | None,
|
||||
user_id: str | None = None,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute *command* on the E2B sandbox via commands.run().
|
||||
|
||||
Integration tokens (e.g. GH_TOKEN) are injected into the sandbox env
|
||||
for any user with connected accounts. E2B has full internet access, so
|
||||
CLI tools like ``gh`` work without manual authentication.
|
||||
"""
|
||||
envs: dict[str, str] = {
|
||||
"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin",
|
||||
}
|
||||
if user_id is not None:
|
||||
integration_env = await get_integration_env_vars(user_id)
|
||||
envs.update(integration_env)
|
||||
|
||||
"""Execute *command* on the E2B sandbox via commands.run()."""
|
||||
try:
|
||||
result = await sandbox.commands.run(
|
||||
f"bash -c {shlex.quote(command)}",
|
||||
cwd=E2B_WORKDIR,
|
||||
timeout=timeout,
|
||||
envs=envs,
|
||||
envs={"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"},
|
||||
)
|
||||
return BashExecResponse(
|
||||
message=f"Command executed on E2B (exit {result.exit_code})",
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
"""Tests for BashExecTool — E2B path with token injection."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ._test_data import make_session
|
||||
from .bash_exec import BashExecTool
|
||||
from .models import BashExecResponse
|
||||
|
||||
_USER = "user-bash-exec-test"
|
||||
|
||||
|
||||
def _make_tool() -> BashExecTool:
|
||||
return BashExecTool()
|
||||
|
||||
|
||||
def _make_sandbox(exit_code: int = 0, stdout: str = "", stderr: str = "") -> MagicMock:
|
||||
result = MagicMock()
|
||||
result.exit_code = exit_code
|
||||
result.stdout = stdout
|
||||
result.stderr = stderr
|
||||
|
||||
sandbox = MagicMock()
|
||||
sandbox.commands.run = AsyncMock(return_value=result)
|
||||
return sandbox
|
||||
|
||||
|
||||
class TestBashExecE2BTokenInjection:
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_token_injected_when_user_id_set(self):
|
||||
"""When user_id is provided, integration env vars are merged into sandbox envs."""
|
||||
tool = _make_tool()
|
||||
session = make_session(user_id=_USER)
|
||||
sandbox = _make_sandbox(stdout="ok")
|
||||
env_vars = {"GH_TOKEN": "gh-secret", "GITHUB_TOKEN": "gh-secret"}
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value=env_vars),
|
||||
) as mock_get_env:
|
||||
result = await tool._execute_on_e2b(
|
||||
sandbox=sandbox,
|
||||
command="echo hi",
|
||||
timeout=10,
|
||||
session_id=session.session_id,
|
||||
user_id=_USER,
|
||||
)
|
||||
|
||||
mock_get_env.assert_awaited_once_with(_USER)
|
||||
call_kwargs = sandbox.commands.run.call_args[1]
|
||||
assert call_kwargs["envs"]["GH_TOKEN"] == "gh-secret"
|
||||
assert call_kwargs["envs"]["GITHUB_TOKEN"] == "gh-secret"
|
||||
assert isinstance(result, BashExecResponse)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_no_token_injection_when_user_id_is_none(self):
|
||||
"""When user_id is None, get_integration_env_vars must NOT be called."""
|
||||
tool = _make_tool()
|
||||
session = make_session(user_id=_USER)
|
||||
sandbox = _make_sandbox(stdout="ok")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.bash_exec.get_integration_env_vars",
|
||||
new=AsyncMock(return_value={"GH_TOKEN": "should-not-appear"}),
|
||||
) as mock_get_env:
|
||||
result = await tool._execute_on_e2b(
|
||||
sandbox=sandbox,
|
||||
command="echo hi",
|
||||
timeout=10,
|
||||
session_id=session.session_id,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
mock_get_env.assert_not_called()
|
||||
call_kwargs = sandbox.commands.run.call_args[1]
|
||||
assert "GH_TOKEN" not in call_kwargs["envs"]
|
||||
assert isinstance(result, BashExecResponse)
|
||||
@@ -1,215 +0,0 @@
|
||||
"""Tool for prompting the user to connect a required integration.
|
||||
|
||||
When the copilot encounters an authentication failure (e.g. `gh` CLI returns
|
||||
"authentication required"), it calls this tool to surface the credentials
|
||||
setup card in the chat — the same UI that appears when a GitHub block runs
|
||||
without configured credentials.
|
||||
"""
|
||||
|
||||
import functools
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.models import (
|
||||
ErrorResponse,
|
||||
ResponseType,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
UserReadiness,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
|
||||
|
||||
class _ProviderInfo(TypedDict):
|
||||
name: str
|
||||
types: list[str]
|
||||
# Default OAuth scopes requested when the agent doesn't specify any.
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
class _CredentialEntry(TypedDict):
|
||||
"""Shape of each entry inside SetupRequirementsResponse.user_readiness.missing_credentials."""
|
||||
|
||||
id: str
|
||||
title: str
|
||||
provider: str
|
||||
provider_name: str
|
||||
type: str
|
||||
types: list[str]
|
||||
scopes: list[str]
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _is_github_oauth_configured() -> bool:
|
||||
"""Return True if GitHub OAuth env vars are set.
|
||||
|
||||
Evaluated lazily (not at import time) to avoid triggering Secrets() during
|
||||
module import, which can fail in environments where secrets are not loaded.
|
||||
"""
|
||||
from backend.blocks.github._auth import GITHUB_OAUTH_IS_CONFIGURED
|
||||
|
||||
return GITHUB_OAUTH_IS_CONFIGURED
|
||||
|
||||
|
||||
# Registry of known providers: name + supported credential types for the UI.
|
||||
# When adding a new provider, also add its env var names to
|
||||
# backend.copilot.integration_creds.PROVIDER_ENV_VARS.
|
||||
def _get_provider_info() -> dict[str, _ProviderInfo]:
|
||||
"""Build the provider registry, evaluating OAuth config lazily."""
|
||||
return {
|
||||
"github": {
|
||||
"name": "GitHub",
|
||||
"types": (
|
||||
["api_key", "oauth2"] if _is_github_oauth_configured() else ["api_key"]
|
||||
),
|
||||
# Default: repo scope covers clone/push/pull for public and private repos.
|
||||
# Agent can request additional scopes (e.g. "read:org") via the scopes param.
|
||||
"scopes": ["repo"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ConnectIntegrationTool(BaseTool):
|
||||
"""Surface the credentials setup UI when an integration is not connected."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "connect_integration"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Prompt the user to connect a required integration (e.g. GitHub). "
|
||||
"Call this when an external CLI or API call fails because the user "
|
||||
"has not connected the relevant account. "
|
||||
"The tool surfaces a credentials setup card in the chat so the user "
|
||||
"can authenticate without leaving the page. "
|
||||
"After the user connects the account, retry the operation. "
|
||||
"In E2B/cloud sandbox mode the token (GH_TOKEN/GITHUB_TOKEN) is "
|
||||
"automatically injected per-command in bash_exec — no manual export needed. "
|
||||
"In local bubblewrap mode network is isolated so GitHub CLI commands "
|
||||
"will still fail after connecting; inform the user of this limitation."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"provider": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Integration provider slug, e.g. 'github'. "
|
||||
"Must be one of the supported providers."
|
||||
),
|
||||
"enum": list(_get_provider_info().keys()),
|
||||
},
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Brief explanation of why the integration is needed, "
|
||||
"shown to the user in the setup card."
|
||||
),
|
||||
"maxLength": 500,
|
||||
},
|
||||
"scopes": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"OAuth scopes to request. Omit to use the provider default. "
|
||||
"Add extra scopes when you need more access — e.g. for GitHub: "
|
||||
"'repo' (clone/push/pull), 'read:org' (org membership), "
|
||||
"'workflow' (GitHub Actions). "
|
||||
"Requesting only the scopes you actually need is best practice."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["provider"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
# Require auth so only authenticated users can trigger the setup card.
|
||||
# The card itself is user-agnostic (no per-user data needed), so
|
||||
# user_id is intentionally unused in _execute.
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponseBase:
|
||||
del user_id # setup card is user-agnostic; auth is enforced via requires_auth
|
||||
session_id = session.session_id if session else None
|
||||
provider: str = (kwargs.get("provider") or "").strip().lower()
|
||||
reason: str = (kwargs.get("reason") or "").strip()[
|
||||
:500
|
||||
] # cap LLM-controlled text
|
||||
extra_scopes: list[str] = [
|
||||
str(s).strip() for s in (kwargs.get("scopes") or []) if str(s).strip()
|
||||
]
|
||||
|
||||
provider_info = _get_provider_info()
|
||||
info = provider_info.get(provider)
|
||||
if not info:
|
||||
supported = ", ".join(f"'{p}'" for p in provider_info)
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Unknown provider '{provider}'. "
|
||||
f"Supported providers: {supported}."
|
||||
),
|
||||
error="unknown_provider",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
provider_name: str = info["name"]
|
||||
supported_types: list[str] = info["types"]
|
||||
# Merge agent-requested scopes with provider defaults (deduplicated, order preserved).
|
||||
default_scopes: list[str] = info["scopes"]
|
||||
seen: set[str] = set()
|
||||
scopes: list[str] = []
|
||||
for s in default_scopes + extra_scopes:
|
||||
if s not in seen:
|
||||
seen.add(s)
|
||||
scopes.append(s)
|
||||
field_key = f"{provider}_credentials"
|
||||
|
||||
message_parts = [
|
||||
f"To continue, please connect your {provider_name} account.",
|
||||
]
|
||||
if reason:
|
||||
message_parts.append(reason)
|
||||
|
||||
credential_entry: _CredentialEntry = {
|
||||
"id": field_key,
|
||||
"title": f"{provider_name} Credentials",
|
||||
"provider": provider,
|
||||
"provider_name": provider_name,
|
||||
"type": supported_types[0],
|
||||
"types": supported_types,
|
||||
"scopes": scopes,
|
||||
}
|
||||
missing_credentials: dict[str, _CredentialEntry] = {field_key: credential_entry}
|
||||
|
||||
return SetupRequirementsResponse(
|
||||
type=ResponseType.SETUP_REQUIREMENTS,
|
||||
message=" ".join(message_parts),
|
||||
session_id=session_id,
|
||||
setup_info=SetupInfo(
|
||||
agent_id=f"connect_{provider}",
|
||||
agent_name=provider_name,
|
||||
user_readiness=UserReadiness(
|
||||
has_all_credentials=False,
|
||||
missing_credentials=missing_credentials,
|
||||
ready_to_run=False,
|
||||
),
|
||||
requirements={
|
||||
"credentials": [missing_credentials[field_key]],
|
||||
"inputs": [],
|
||||
"execution_modes": [],
|
||||
},
|
||||
),
|
||||
)
|
||||
@@ -1,135 +0,0 @@
|
||||
"""Tests for ConnectIntegrationTool."""
|
||||
|
||||
import pytest
|
||||
|
||||
from ._test_data import make_session
|
||||
from .connect_integration import ConnectIntegrationTool
|
||||
from .models import ErrorResponse, SetupRequirementsResponse
|
||||
|
||||
_TEST_USER_ID = "test-user-connect-integration"
|
||||
|
||||
|
||||
class TestConnectIntegrationTool:
|
||||
def _make_tool(self) -> ConnectIntegrationTool:
|
||||
return ConnectIntegrationTool()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unknown_provider_returns_error(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="nonexistent"
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "unknown_provider"
|
||||
assert "nonexistent" in result.message
|
||||
assert "github" in result.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_empty_provider_returns_error(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider=""
|
||||
)
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "unknown_provider"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_github_provider_returns_setup_response(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert result.setup_info.agent_name == "GitHub"
|
||||
assert result.setup_info.agent_id == "connect_github"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_github_has_missing_credentials_in_readiness(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
readiness = result.setup_info.user_readiness
|
||||
assert readiness.has_all_credentials is False
|
||||
assert readiness.ready_to_run is False
|
||||
assert "github_credentials" in readiness.missing_credentials
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_github_requirements_include_credential_entry(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
creds = result.setup_info.requirements["credentials"]
|
||||
assert len(creds) == 1
|
||||
assert creds[0]["provider"] == "github"
|
||||
assert creds[0]["id"] == "github_credentials"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_reason_appears_in_message(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
reason = "Needed to create a pull request."
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github", reason=reason
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert reason in result.message
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_session_id_propagated(self):
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="github"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
assert result.session_id == session.session_id
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_provider_case_insensitive(self):
|
||||
"""Provider slug is normalised to lowercase before lookup."""
|
||||
tool = self._make_tool()
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, provider="GitHub"
|
||||
)
|
||||
assert isinstance(result, SetupRequirementsResponse)
|
||||
|
||||
def test_tool_name(self):
|
||||
assert ConnectIntegrationTool().name == "connect_integration"
|
||||
|
||||
def test_requires_auth(self):
|
||||
assert ConnectIntegrationTool().requires_auth is True
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_unauthenticated_user_gets_need_login_response(self):
|
||||
"""execute() with user_id=None must return NeedLoginResponse, not the setup card.
|
||||
|
||||
This verifies that the requires_auth guard in BaseTool.execute() fires
|
||||
before _execute() is called, so unauthenticated callers cannot probe
|
||||
which integrations are configured.
|
||||
"""
|
||||
import json
|
||||
|
||||
tool = self._make_tool()
|
||||
# Session still needs a user_id string; the None is passed to execute()
|
||||
# to simulate an unauthenticated call.
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
result = await tool.execute(
|
||||
user_id=None,
|
||||
session=session,
|
||||
tool_call_id="test-call-id",
|
||||
provider="github",
|
||||
)
|
||||
raw = result.output
|
||||
output = json.loads(raw) if isinstance(raw, str) else raw
|
||||
assert output.get("type") == "need_login"
|
||||
assert result.success is False
|
||||
@@ -19,7 +19,8 @@ class FindAgentTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Discover agents from the marketplace based on capabilities and user needs."
|
||||
"Discover agents from the marketplace based on capabilities and "
|
||||
"user needs, or look up a specific agent by its creator/slug ID."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -29,7 +30,7 @@ class FindAgentTool(BaseTool):
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query describing what the user wants to accomplish. Use single keywords for best results.",
|
||||
"description": "Search query describing what the user wants to accomplish, or a creator/slug ID (e.g. 'username/agent-name') for direct lookup. Use single keywords for best results.",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
|
||||
@@ -15,6 +15,7 @@ from .models import (
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
)
|
||||
from .utils import is_uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -52,7 +53,8 @@ class FindBlockTool(BaseTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Search for available blocks by name or description. "
|
||||
"Search for available blocks by name or description, or look up a "
|
||||
"specific block by its ID. "
|
||||
"Blocks are reusable components that perform specific tasks like "
|
||||
"sending emails, making API calls, processing text, etc. "
|
||||
"IMPORTANT: Use this tool FIRST to get the block's 'id' before calling run_block. "
|
||||
@@ -68,7 +70,8 @@ class FindBlockTool(BaseTool):
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Search query to find blocks by name or description. "
|
||||
"Search query to find blocks by name or description, "
|
||||
"or a block ID (UUID) for direct lookup. "
|
||||
"Use keywords like 'email', 'http', 'text', 'ai', etc."
|
||||
),
|
||||
},
|
||||
@@ -113,11 +116,77 @@ class FindBlockTool(BaseTool):
|
||||
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
message="Please provide a search query or block ID",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Direct ID lookup if query looks like a UUID
|
||||
if is_uuid(query):
|
||||
block = get_block(query.lower())
|
||||
if block:
|
||||
if block.disabled:
|
||||
return NoResultsResponse(
|
||||
message=f"Block '{block.name}' (ID: {query}) is disabled and cannot be used.",
|
||||
suggestions=["Search for an alternative block by name"],
|
||||
session_id=session_id,
|
||||
)
|
||||
if (
|
||||
block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES
|
||||
or block.id in COPILOT_EXCLUDED_BLOCK_IDS
|
||||
):
|
||||
if block.block_type == BlockType.MCP_TOOL:
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' (ID: {block.id}) is not "
|
||||
"runnable through find_block/run_block. Use "
|
||||
"run_mcp_tool instead."
|
||||
),
|
||||
suggestions=[
|
||||
"Use run_mcp_tool to discover and run this MCP tool",
|
||||
"Search for an alternative block by name",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
return NoResultsResponse(
|
||||
message=(
|
||||
f"Block '{block.name}' (ID: {block.id}) is not available "
|
||||
"in CoPilot. It can only be used within agent graphs."
|
||||
),
|
||||
suggestions=[
|
||||
"Search for an alternative block by name",
|
||||
"Use this block in an agent graph instead",
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
summary = BlockInfoSummary(
|
||||
id=block.id,
|
||||
name=block.name,
|
||||
description=(
|
||||
block.optimized_description or block.description or ""
|
||||
),
|
||||
categories=[c.value for c in block.categories],
|
||||
)
|
||||
if include_schemas:
|
||||
info = block.get_info()
|
||||
summary.input_schema = info.inputSchema
|
||||
summary.output_schema = info.outputSchema
|
||||
summary.static_output = info.staticOutput
|
||||
|
||||
return BlockListResponse(
|
||||
message=(
|
||||
f"Found block '{block.name}' by ID. "
|
||||
"To see inputs/outputs and execute it, use "
|
||||
"run_block with the block's 'id' - providing "
|
||||
"no inputs."
|
||||
),
|
||||
blocks=[summary],
|
||||
count=1,
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Search for blocks using hybrid search
|
||||
results, total = await search().unified_hybrid_search(
|
||||
query=query,
|
||||
|
||||
@@ -499,3 +499,123 @@ class TestFindBlockFiltering:
|
||||
assert response.blocks[0].input_schema == input_schema
|
||||
assert response.blocks[0].output_schema == output_schema
|
||||
assert response.blocks[0].static_output is True
|
||||
|
||||
|
||||
class TestFindBlockDirectLookup:
|
||||
"""Tests for direct UUID lookup in FindBlockTool."""
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_uuid_lookup_found(self):
|
||||
"""UUID query returns the block directly without search."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
|
||||
block = make_mock_block(block_id, "Test Block", BlockType.STANDARD)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
return_value=block,
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, query=block_id
|
||||
)
|
||||
|
||||
assert isinstance(response, BlockListResponse)
|
||||
assert response.count == 1
|
||||
assert response.blocks[0].id == block_id
|
||||
assert response.blocks[0].name == "Test Block"
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_uuid_lookup_not_found_falls_through(self):
|
||||
"""UUID that doesn't match any block falls through to search."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
|
||||
|
||||
mock_search_db = MagicMock()
|
||||
mock_search_db.unified_hybrid_search = AsyncMock(return_value=([], 0))
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.find_block.search",
|
||||
return_value=mock_search_db,
|
||||
),
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, query=block_id
|
||||
)
|
||||
|
||||
from .models import NoResultsResponse
|
||||
|
||||
assert isinstance(response, NoResultsResponse)
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_uuid_lookup_disabled_block(self):
|
||||
"""UUID matching a disabled block returns NoResultsResponse."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
|
||||
block = make_mock_block(
|
||||
block_id, "Disabled Block", BlockType.STANDARD, disabled=True
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
return_value=block,
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, query=block_id
|
||||
)
|
||||
|
||||
from .models import NoResultsResponse
|
||||
|
||||
assert isinstance(response, NoResultsResponse)
|
||||
assert "disabled" in response.message.lower()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_uuid_lookup_excluded_block_type(self):
|
||||
"""UUID matching an excluded block type returns NoResultsResponse."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d"
|
||||
block = make_mock_block(block_id, "Input Block", BlockType.INPUT)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
return_value=block,
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, query=block_id
|
||||
)
|
||||
|
||||
from .models import NoResultsResponse
|
||||
|
||||
assert isinstance(response, NoResultsResponse)
|
||||
assert "not available" in response.message.lower()
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_uuid_lookup_excluded_block_id(self):
|
||||
"""UUID matching an excluded block ID returns NoResultsResponse."""
|
||||
session = make_session(user_id=_TEST_USER_ID)
|
||||
smart_decision_id = "3b191d9f-356f-482d-8238-ba04b6d18381"
|
||||
block = make_mock_block(
|
||||
smart_decision_id, "Smart Decision Maker", BlockType.STANDARD
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.tools.find_block.get_block",
|
||||
return_value=block,
|
||||
):
|
||||
tool = FindBlockTool()
|
||||
response = await tool._execute(
|
||||
user_id=_TEST_USER_ID, session=session, query=smart_decision_id
|
||||
)
|
||||
|
||||
from .models import NoResultsResponse
|
||||
|
||||
assert isinstance(response, NoResultsResponse)
|
||||
assert "not available" in response.message.lower()
|
||||
|
||||
@@ -12,7 +12,6 @@ from backend.copilot.constants import (
|
||||
COPILOT_SESSION_PREFIX,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.sdk.file_ref import FileRefExpansionError, expand_file_refs_in_args
|
||||
from backend.data.db_accessors import review_db
|
||||
from backend.data.execution import ExecutionContext
|
||||
|
||||
@@ -198,29 +197,6 @@ class RunBlockTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Expand @@agptfile: refs in input_data with the block's input
|
||||
# schema. The generic _truncating wrapper skips opaque object
|
||||
# properties (input_data has no declared inner properties in the
|
||||
# tool schema), so file ref tokens are still intact here.
|
||||
# Using the block's schema lets us return raw text for string-typed
|
||||
# fields and parsed structures for list/dict-typed fields.
|
||||
if input_data:
|
||||
try:
|
||||
input_data = await expand_file_refs_in_args(
|
||||
input_data,
|
||||
user_id,
|
||||
session,
|
||||
input_schema=input_schema,
|
||||
)
|
||||
except FileRefExpansionError as exc:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Failed to resolve file reference: {exc}. "
|
||||
"Ensure the file exists before referencing it."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if missing_credentials:
|
||||
# Return setup requirements response with missing credentials
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Shared utilities for chat tools."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.library import model as library_model
|
||||
@@ -19,6 +20,26 @@ from backend.util.exceptions import NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Shared UUID v4 pattern used by multiple tools for direct ID lookups.
|
||||
_UUID_V4_PATTERN = re.compile(
|
||||
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def is_uuid(text: str) -> bool:
|
||||
"""Check if text is a valid UUID v4."""
|
||||
return bool(_UUID_V4_PATTERN.match(text.strip()))
|
||||
|
||||
|
||||
# Matches "creator/slug" identifiers used in the marketplace
|
||||
_CREATOR_SLUG_PATTERN = re.compile(r"^[\w-]+/[\w-]+$")
|
||||
|
||||
|
||||
def is_creator_slug(text: str) -> bool:
|
||||
"""Check if text matches a 'creator/slug' marketplace identifier."""
|
||||
return bool(_CREATOR_SLUG_PATTERN.match(text.strip()))
|
||||
|
||||
|
||||
async def fetch_graph_from_store_slug(
|
||||
username: str,
|
||||
|
||||
@@ -10,11 +10,11 @@ from pydantic import BaseModel
|
||||
from backend.copilot.context import (
|
||||
E2B_WORKDIR,
|
||||
get_current_sandbox,
|
||||
get_workspace_manager,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.sandbox import make_session_path
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
@@ -218,6 +218,12 @@ def _is_text_mime(mime_type: str) -> bool:
|
||||
return any(mime_type.startswith(t) for t in _TEXT_MIME_PREFIXES)
|
||||
|
||||
|
||||
async def get_manager(user_id: str, session_id: str) -> WorkspaceManager:
|
||||
"""Create a session-scoped WorkspaceManager."""
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
return WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
|
||||
async def _resolve_file(
|
||||
manager: WorkspaceManager,
|
||||
file_id: str | None,
|
||||
@@ -380,7 +386,7 @@ class ListWorkspaceFilesTool(BaseTool):
|
||||
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
files = await manager.list_files(
|
||||
path=path_prefix, limit=limit, include_all_sessions=include_all_sessions
|
||||
)
|
||||
@@ -530,7 +536,7 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
@@ -766,7 +772,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
||||
|
||||
try:
|
||||
await scan_content_safe(content, filename=filename)
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
rec = await manager.write_file(
|
||||
content=content,
|
||||
filename=filename,
|
||||
@@ -893,7 +899,7 @@ class DeleteWorkspaceFileTool(BaseTool):
|
||||
)
|
||||
|
||||
try:
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
manager = await get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
|
||||
@@ -25,35 +25,6 @@ logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
|
||||
_on_creds_changed: Callable[[str, str], None] | None = None
|
||||
|
||||
|
||||
def register_creds_changed_hook(hook: Callable[[str, str], None]) -> None:
|
||||
"""Register a callback invoked after any credential is created/updated/deleted.
|
||||
|
||||
The callback receives ``(user_id, provider)`` and should be idempotent.
|
||||
Only one hook can be registered at a time; calling this again replaces the
|
||||
previous hook. Intended to be called once at application startup by the
|
||||
copilot module to bust its token cache without creating an import cycle.
|
||||
"""
|
||||
global _on_creds_changed
|
||||
_on_creds_changed = hook
|
||||
|
||||
|
||||
def _bust_copilot_cache(user_id: str, provider: str) -> None:
|
||||
"""Invoke the registered hook (if any) to bust downstream token caches."""
|
||||
if _on_creds_changed is not None:
|
||||
try:
|
||||
_on_creds_changed(user_id, provider)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Credential-change hook failed for user=%s provider=%s",
|
||||
user_id,
|
||||
provider,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
class IntegrationCredentialsManager:
|
||||
"""
|
||||
Handles the lifecycle of integration credentials.
|
||||
@@ -98,11 +69,7 @@ class IntegrationCredentialsManager:
|
||||
return self._locks
|
||||
|
||||
async def create(self, user_id: str, credentials: Credentials) -> None:
|
||||
result = await self.store.add_creds(user_id, credentials)
|
||||
# Bust the copilot token cache so that the next bash_exec picks up the
|
||||
# new credential immediately instead of waiting for _NULL_CACHE_TTL.
|
||||
_bust_copilot_cache(user_id, credentials.provider)
|
||||
return result
|
||||
return await self.store.add_creds(user_id, credentials)
|
||||
|
||||
async def exists(self, user_id: str, credentials_id: str) -> bool:
|
||||
return (await self.store.get_creds_by_id(user_id, credentials_id)) is not None
|
||||
@@ -189,8 +156,6 @@ class IntegrationCredentialsManager:
|
||||
|
||||
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
|
||||
await self.store.update_creds(user_id, fresh_credentials)
|
||||
# Bust copilot cache so the refreshed token is picked up immediately.
|
||||
_bust_copilot_cache(user_id, fresh_credentials.provider)
|
||||
if _lock and (await _lock.locked()) and (await _lock.owned()):
|
||||
try:
|
||||
await _lock.release()
|
||||
@@ -203,17 +168,10 @@ class IntegrationCredentialsManager:
|
||||
async def update(self, user_id: str, updated: Credentials) -> None:
|
||||
async with self._locked(user_id, updated.id):
|
||||
await self.store.update_creds(user_id, updated)
|
||||
# Bust the copilot token cache so the updated credential is picked up immediately.
|
||||
_bust_copilot_cache(user_id, updated.provider)
|
||||
|
||||
async def delete(self, user_id: str, credentials_id: str) -> None:
|
||||
async with self._locked(user_id, credentials_id):
|
||||
# Read inside the lock to avoid TOCTOU — another coroutine could
|
||||
# delete the same credential between the read and the delete.
|
||||
creds = await self.store.get_creds_by_id(user_id, credentials_id)
|
||||
await self.store.delete_creds_by_id(user_id, credentials_id)
|
||||
if creds:
|
||||
_bust_copilot_cache(user_id, creds.provider)
|
||||
|
||||
# -- Locking utilities -- #
|
||||
|
||||
|
||||
@@ -275,12 +275,13 @@ async def store_media_file(
|
||||
# Process file
|
||||
elif file.startswith("data:"):
|
||||
# Data URI
|
||||
parsed_uri = parse_data_uri(file)
|
||||
if parsed_uri is None:
|
||||
match = re.match(r"^data:([^;]+);base64,(.*)$", file, re.DOTALL)
|
||||
if not match:
|
||||
raise ValueError(
|
||||
"Invalid data URI format. Expected data:<mime>;base64,<data>"
|
||||
)
|
||||
mime_type, b64_content = parsed_uri
|
||||
mime_type = match.group(1).strip().lower()
|
||||
b64_content = match.group(2).strip()
|
||||
|
||||
# Generate filename and decode
|
||||
extension = _extension_from_mime(mime_type)
|
||||
@@ -414,70 +415,13 @@ def get_dir_size(path: Path) -> int:
|
||||
return total
|
||||
|
||||
|
||||
async def resolve_media_content(
|
||||
content: MediaFileType,
|
||||
execution_context: "ExecutionContext",
|
||||
*,
|
||||
return_format: MediaReturnFormat,
|
||||
) -> MediaFileType:
|
||||
"""Resolve a ``MediaFileType`` value if it is a media reference, pass through otherwise.
|
||||
|
||||
Convenience wrapper around :func:`is_media_file_ref` + :func:`store_media_file`.
|
||||
Plain text content (source code, filenames) is returned unchanged. Media
|
||||
references (``data:``, ``workspace://``, ``http(s)://``) are resolved via
|
||||
:func:`store_media_file` using *return_format*.
|
||||
|
||||
Use this when a block field is typed as ``MediaFileType`` but may contain
|
||||
either literal text or a media reference.
|
||||
"""
|
||||
if not content or not is_media_file_ref(content):
|
||||
return content
|
||||
return await store_media_file(
|
||||
content, execution_context, return_format=return_format
|
||||
)
|
||||
|
||||
|
||||
def is_media_file_ref(value: str) -> bool:
|
||||
"""Return True if *value* looks like a ``MediaFileType`` reference.
|
||||
|
||||
Detects data URIs, workspace:// references, and HTTP(S) URLs — the
|
||||
formats accepted by :func:`store_media_file`. Plain text content
|
||||
(e.g. source code, filenames) returns False.
|
||||
|
||||
Known limitation: HTTP(S) URL detection is heuristic. Any string that
|
||||
starts with ``http://`` or ``https://`` is treated as a media URL, even
|
||||
if it appears as a URL inside source-code comments or documentation.
|
||||
Blocks that produce source code or Markdown as output may therefore
|
||||
trigger false positives. Callers that need higher precision should
|
||||
inspect the string further (e.g. verify the URL is reachable or has a
|
||||
media-friendly extension).
|
||||
|
||||
Note: this does *not* match local file paths, which are ambiguous
|
||||
(could be filenames or actual paths). Blocks that need to resolve
|
||||
local paths should check for them separately.
|
||||
"""
|
||||
return value.startswith(("data:", "workspace://", "http://", "https://"))
|
||||
|
||||
|
||||
def parse_data_uri(value: str) -> tuple[str, str] | None:
|
||||
"""Parse a ``data:<mime>;base64,<payload>`` URI.
|
||||
|
||||
Returns ``(mime_type, base64_payload)`` if *value* is a valid data URI,
|
||||
or ``None`` if it is not.
|
||||
"""
|
||||
match = re.match(r"^data:([^;]+);base64,(.*)$", value, re.DOTALL)
|
||||
if not match:
|
||||
return None
|
||||
return match.group(1).strip().lower(), match.group(2).strip()
|
||||
|
||||
|
||||
def get_mime_type(file: str) -> str:
|
||||
"""
|
||||
Get the MIME type of a file, whether it's a data URI, URL, or local path.
|
||||
"""
|
||||
if file.startswith("data:"):
|
||||
parsed_uri = parse_data_uri(file)
|
||||
return parsed_uri[0] if parsed_uri else "application/octet-stream"
|
||||
match = re.match(r"^data:([^;]+);base64,", file)
|
||||
return match.group(1) if match else "application/octet-stream"
|
||||
|
||||
elif file.startswith(("http://", "https://")):
|
||||
parsed_url = urlparse(file)
|
||||
|
||||
@@ -1,375 +0,0 @@
|
||||
"""Parse file content into structured Python objects based on file format.
|
||||
|
||||
Used by the ``@@agptfile:`` expansion system to eagerly parse well-known file
|
||||
formats into native Python types *before* schema-driven coercion runs. This
|
||||
lets blocks with ``Any``-typed inputs receive structured data rather than raw
|
||||
strings, while blocks expecting strings get the value coerced back via
|
||||
``convert()``.
|
||||
|
||||
Supported formats:
|
||||
|
||||
- **JSON** (``.json``) — arrays and objects are promoted; scalars stay as strings
|
||||
- **JSON Lines** (``.jsonl``, ``.ndjson``) — each non-empty line parsed as JSON;
|
||||
when all lines are dicts with the same keys (tabular data), output is
|
||||
``list[list[Any]]`` with a header row, consistent with CSV/Parquet/Excel;
|
||||
otherwise returns a plain ``list`` of parsed values
|
||||
- **CSV** (``.csv``) — ``csv.reader`` → ``list[list[str]]``
|
||||
- **TSV** (``.tsv``) — tab-delimited → ``list[list[str]]``
|
||||
- **YAML** (``.yaml``, ``.yml``) — parsed via PyYAML; containers only
|
||||
- **TOML** (``.toml``) — parsed via stdlib ``tomllib``
|
||||
- **Parquet** (``.parquet``) — via pandas/pyarrow → ``list[list[Any]]`` with header row
|
||||
- **Excel** (``.xlsx``) — via pandas/openpyxl → ``list[list[Any]]`` with header row
|
||||
(legacy ``.xls`` is **not** supported — only the modern OOXML format)
|
||||
|
||||
The **fallback contract** is enforced by :func:`parse_file_content`, not by
|
||||
individual parser functions. If any parser raises, ``parse_file_content``
|
||||
catches the exception and returns the original content unchanged (string for
|
||||
text formats, bytes for binary formats). Callers should never see an
|
||||
exception from the public API when ``strict=False``.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import tomllib
|
||||
import zipfile
|
||||
from collections.abc import Callable
|
||||
|
||||
# posixpath.splitext handles forward-slash URI paths correctly on all platforms,
|
||||
# unlike os.path.splitext which uses platform-native separators.
|
||||
from posixpath import splitext
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Extension / MIME → format label mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_EXT_TO_FORMAT: dict[str, str] = {
|
||||
".json": "json",
|
||||
".jsonl": "jsonl",
|
||||
".ndjson": "jsonl",
|
||||
".csv": "csv",
|
||||
".tsv": "tsv",
|
||||
".yaml": "yaml",
|
||||
".yml": "yaml",
|
||||
".toml": "toml",
|
||||
".parquet": "parquet",
|
||||
".xlsx": "xlsx",
|
||||
}
|
||||
|
||||
MIME_TO_FORMAT: dict[str, str] = {
|
||||
"application/json": "json",
|
||||
"application/x-ndjson": "jsonl",
|
||||
"application/jsonl": "jsonl",
|
||||
"text/csv": "csv",
|
||||
"text/tab-separated-values": "tsv",
|
||||
"application/x-yaml": "yaml",
|
||||
"application/yaml": "yaml",
|
||||
"text/yaml": "yaml",
|
||||
"application/toml": "toml",
|
||||
"application/vnd.apache.parquet": "parquet",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx",
|
||||
}
|
||||
|
||||
# Formats that require raw bytes rather than decoded text.
|
||||
BINARY_FORMATS: frozenset[str] = frozenset({"parquet", "xlsx"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API (top-down: main functions first, helpers below)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def infer_format_from_uri(uri: str) -> str | None:
|
||||
"""Return a format label based on URI extension or MIME fragment.
|
||||
|
||||
Returns ``None`` when the format cannot be determined — the caller should
|
||||
fall back to returning the content as a plain string.
|
||||
"""
|
||||
# 1. Check MIME fragment (workspace://abc123#application/json)
|
||||
if "#" in uri:
|
||||
_, fragment = uri.rsplit("#", 1)
|
||||
fmt = MIME_TO_FORMAT.get(fragment.lower())
|
||||
if fmt:
|
||||
return fmt
|
||||
|
||||
# 2. Check file extension from the path portion.
|
||||
# Strip the fragment first so ".json#mime" doesn't confuse splitext.
|
||||
path = uri.split("#")[0].split("?")[0]
|
||||
_, ext = splitext(path)
|
||||
fmt = _EXT_TO_FORMAT.get(ext.lower())
|
||||
if fmt is not None:
|
||||
return fmt
|
||||
|
||||
# Legacy .xls is not supported — map it so callers can produce a
|
||||
# user-friendly error instead of returning garbled binary.
|
||||
if ext.lower() == ".xls":
|
||||
return "xls"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def parse_file_content(content: str | bytes, fmt: str, *, strict: bool = False) -> Any:
|
||||
"""Parse *content* according to *fmt* and return a native Python value.
|
||||
|
||||
When *strict* is ``False`` (default), returns the original *content*
|
||||
unchanged if *fmt* is not recognised or parsing fails for any reason.
|
||||
This mode **never raises**.
|
||||
|
||||
When *strict* is ``True``, parsing errors are propagated to the caller.
|
||||
Unrecognised formats or type mismatches (e.g. text for a binary format)
|
||||
still return *content* unchanged without raising.
|
||||
"""
|
||||
if fmt == "xls":
|
||||
return (
|
||||
"[Unsupported format] Legacy .xls files are not supported. "
|
||||
"Please re-save the file as .xlsx (Excel 2007+) and upload again."
|
||||
)
|
||||
|
||||
try:
|
||||
if fmt in BINARY_FORMATS:
|
||||
parser = _BINARY_PARSERS.get(fmt)
|
||||
if parser is None:
|
||||
return content
|
||||
if isinstance(content, str):
|
||||
# Caller gave us text for a binary format — can't parse.
|
||||
return content
|
||||
return parser(content)
|
||||
|
||||
parser = _TEXT_PARSERS.get(fmt)
|
||||
if parser is None:
|
||||
return content
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8", errors="replace")
|
||||
return parser(content)
|
||||
|
||||
except PARSE_EXCEPTIONS:
|
||||
if strict:
|
||||
raise
|
||||
logger.debug("Structured parsing failed for format=%s, falling back", fmt)
|
||||
return content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Exception loading helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _load_openpyxl_exception() -> type[Exception]:
|
||||
"""Return openpyxl's InvalidFileException, raising ImportError if absent."""
|
||||
from openpyxl.utils.exceptions import InvalidFileException # noqa: PLC0415
|
||||
|
||||
return InvalidFileException
|
||||
|
||||
|
||||
def _load_arrow_exception() -> type[Exception]:
|
||||
"""Return pyarrow's ArrowException, raising ImportError if absent."""
|
||||
from pyarrow import ArrowException # noqa: PLC0415
|
||||
|
||||
return ArrowException
|
||||
|
||||
|
||||
def _optional_exc(loader: "Callable[[], type[Exception]]") -> "type[Exception] | None":
|
||||
"""Return the exception class from *loader*, or ``None`` if the dep is absent."""
|
||||
try:
|
||||
return loader()
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
# Exception types that can be raised during file content parsing.
|
||||
# Shared between ``parse_file_content`` (which catches them in non-strict mode)
|
||||
# and ``file_ref._expand_bare_ref`` (which re-raises them as FileRefExpansionError).
|
||||
#
|
||||
# Optional-dependency exception types are loaded via a helper that raises
|
||||
# ``ImportError`` at *parse time* rather than silently becoming ``None`` here.
|
||||
# This ensures mypy sees clean types and missing deps surface as real errors.
|
||||
PARSE_EXCEPTIONS: tuple[type[BaseException], ...] = tuple(
|
||||
exc
|
||||
for exc in (
|
||||
json.JSONDecodeError,
|
||||
csv.Error,
|
||||
yaml.YAMLError,
|
||||
tomllib.TOMLDecodeError,
|
||||
ValueError,
|
||||
UnicodeDecodeError,
|
||||
ImportError,
|
||||
OSError,
|
||||
KeyError,
|
||||
TypeError,
|
||||
zipfile.BadZipFile,
|
||||
_optional_exc(_load_openpyxl_exception),
|
||||
# ArrowException covers ArrowIOError and ArrowCapacityError which
|
||||
# do not inherit from standard exceptions; ArrowInvalid/ArrowTypeError
|
||||
# already map to ValueError/TypeError but this catches the rest.
|
||||
_optional_exc(_load_arrow_exception),
|
||||
)
|
||||
if exc is not None
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Text-based parsers (content: str → Any)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_container(parser: Callable[[str], Any], content: str) -> list | dict | str:
|
||||
"""Parse *content* and return the result only if it is a container (list/dict).
|
||||
|
||||
Scalar values (strings, numbers, booleans, None) are discarded and the
|
||||
original *content* string is returned instead. This prevents e.g. a JSON
|
||||
file containing just ``"42"`` from silently becoming an int.
|
||||
"""
|
||||
parsed = parser(content)
|
||||
if isinstance(parsed, (list, dict)):
|
||||
return parsed
|
||||
return content
|
||||
|
||||
|
||||
def _parse_json(content: str) -> list | dict | str:
|
||||
return _parse_container(json.loads, content)
|
||||
|
||||
|
||||
def _parse_jsonl(content: str) -> Any:
|
||||
lines = [json.loads(line) for line in content.splitlines() if line.strip()]
|
||||
if not lines:
|
||||
return content
|
||||
|
||||
# When every line is a dict with the same keys, convert to table format
|
||||
# (header row + data rows) — consistent with CSV/TSV/Parquet/Excel output.
|
||||
# Require ≥2 dicts so a single-line JSONL stays as [dict] (not a table).
|
||||
if len(lines) >= 2 and all(isinstance(obj, dict) for obj in lines):
|
||||
keys = list(lines[0].keys())
|
||||
# Cache as tuple to avoid O(n×k) list allocations in the all() call.
|
||||
keys_tuple = tuple(keys)
|
||||
if keys and all(tuple(obj.keys()) == keys_tuple for obj in lines[1:]):
|
||||
return [keys] + [[obj[k] for k in keys] for obj in lines]
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _parse_csv(content: str) -> Any:
|
||||
return _parse_delimited(content, delimiter=",")
|
||||
|
||||
|
||||
def _parse_tsv(content: str) -> Any:
|
||||
return _parse_delimited(content, delimiter="\t")
|
||||
|
||||
|
||||
def _parse_delimited(content: str, *, delimiter: str) -> Any:
|
||||
reader = csv.reader(io.StringIO(content), delimiter=delimiter)
|
||||
# csv.reader never yields [] — blank lines yield [""]. Filter out
|
||||
# rows where every cell is empty (i.e. truly blank lines).
|
||||
rows = [row for row in reader if _row_has_content(row)]
|
||||
if not rows:
|
||||
return content
|
||||
# If the declared delimiter produces only single-column rows, try
|
||||
# sniffing the actual delimiter — catches misidentified files (e.g.
|
||||
# a tab-delimited file with a .csv extension).
|
||||
if len(rows[0]) == 1:
|
||||
try:
|
||||
dialect = csv.Sniffer().sniff(content[:8192])
|
||||
if dialect.delimiter != delimiter:
|
||||
reader = csv.reader(io.StringIO(content), dialect)
|
||||
rows = [row for row in reader if _row_has_content(row)]
|
||||
except csv.Error:
|
||||
pass
|
||||
if rows and len(rows[0]) >= 2:
|
||||
return rows
|
||||
return content
|
||||
|
||||
|
||||
def _row_has_content(row: list[str]) -> bool:
|
||||
"""Return True when *row* contains at least one non-empty cell.
|
||||
|
||||
``csv.reader`` never yields ``[]`` — truly blank lines yield ``[""]``.
|
||||
This predicate filters those out consistently across the initial read
|
||||
and the sniffer-fallback re-read.
|
||||
"""
|
||||
return any(cell for cell in row)
|
||||
|
||||
|
||||
def _parse_yaml(content: str) -> list | dict | str:
|
||||
# NOTE: YAML anchor/alias expansion can amplify input beyond the 10MB cap.
|
||||
# safe_load prevents code execution; for production hardening consider
|
||||
# a YAML parser with expansion limits (e.g. ruamel.yaml with max_alias_count).
|
||||
if "\n---" in content or content.startswith("---\n"):
|
||||
# Multi-document YAML: only the first document is parsed; the rest
|
||||
# are silently ignored by yaml.safe_load. Warn so callers are aware.
|
||||
logger.warning(
|
||||
"Multi-document YAML detected (--- separator); "
|
||||
"only the first document will be parsed."
|
||||
)
|
||||
return _parse_container(yaml.safe_load, content)
|
||||
|
||||
|
||||
def _parse_toml(content: str) -> Any:
|
||||
parsed = tomllib.loads(content)
|
||||
# tomllib.loads always returns a dict — return it even if empty.
|
||||
return parsed
|
||||
|
||||
|
||||
_TEXT_PARSERS: dict[str, Callable[[str], Any]] = {
|
||||
"json": _parse_json,
|
||||
"jsonl": _parse_jsonl,
|
||||
"csv": _parse_csv,
|
||||
"tsv": _parse_tsv,
|
||||
"yaml": _parse_yaml,
|
||||
"toml": _parse_toml,
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Binary-based parsers (content: bytes → Any)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_parquet(content: bytes) -> list[list[Any]]:
|
||||
import pandas as pd
|
||||
|
||||
df = pd.read_parquet(io.BytesIO(content))
|
||||
return _df_to_rows(df)
|
||||
|
||||
|
||||
def _parse_xlsx(content: bytes) -> list[list[Any]]:
|
||||
import pandas as pd
|
||||
|
||||
# Explicitly specify openpyxl engine; the default engine varies by pandas
|
||||
# version and does not support legacy .xls (which is excluded by our format map).
|
||||
df = pd.read_excel(io.BytesIO(content), engine="openpyxl")
|
||||
return _df_to_rows(df)
|
||||
|
||||
|
||||
def _df_to_rows(df: Any) -> list[list[Any]]:
|
||||
"""Convert a DataFrame to ``list[list[Any]]`` with a header row.
|
||||
|
||||
NaN values are replaced with ``None`` so the result is JSON-serializable.
|
||||
Uses explicit cell-level checking because ``df.where(df.notna(), None)``
|
||||
silently converts ``None`` back to ``NaN`` in float64 columns.
|
||||
"""
|
||||
header = df.columns.tolist()
|
||||
rows = [
|
||||
[None if _is_nan(cell) else cell for cell in row] for row in df.values.tolist()
|
||||
]
|
||||
return [header] + rows
|
||||
|
||||
|
||||
def _is_nan(cell: Any) -> bool:
|
||||
"""Check if a cell value is NaN, handling non-scalar types (lists, dicts).
|
||||
|
||||
``pd.isna()`` on a list/dict returns a boolean array which raises
|
||||
``ValueError`` in a boolean context. Guard with a scalar check first.
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
return bool(pd.api.types.is_scalar(cell) and pd.isna(cell))
|
||||
|
||||
|
||||
_BINARY_PARSERS: dict[str, Callable[[bytes], Any]] = {
|
||||
"parquet": _parse_parquet,
|
||||
"xlsx": _parse_xlsx,
|
||||
}
|
||||
@@ -1,624 +0,0 @@
|
||||
"""Tests for file_content_parser — format inference and structured parsing."""
|
||||
|
||||
import io
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util.file_content_parser import (
|
||||
BINARY_FORMATS,
|
||||
infer_format_from_uri,
|
||||
parse_file_content,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# infer_format_from_uri
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInferFormat:
|
||||
# --- extension-based ---
|
||||
|
||||
def test_json_extension(self):
|
||||
assert infer_format_from_uri("/home/user/data.json") == "json"
|
||||
|
||||
def test_jsonl_extension(self):
|
||||
assert infer_format_from_uri("/tmp/events.jsonl") == "jsonl"
|
||||
|
||||
def test_ndjson_extension(self):
|
||||
assert infer_format_from_uri("/tmp/events.ndjson") == "jsonl"
|
||||
|
||||
def test_csv_extension(self):
|
||||
assert infer_format_from_uri("workspace:///reports/sales.csv") == "csv"
|
||||
|
||||
def test_tsv_extension(self):
|
||||
assert infer_format_from_uri("/home/user/data.tsv") == "tsv"
|
||||
|
||||
def test_yaml_extension(self):
|
||||
assert infer_format_from_uri("/home/user/config.yaml") == "yaml"
|
||||
|
||||
def test_yml_extension(self):
|
||||
assert infer_format_from_uri("/home/user/config.yml") == "yaml"
|
||||
|
||||
def test_toml_extension(self):
|
||||
assert infer_format_from_uri("/home/user/config.toml") == "toml"
|
||||
|
||||
def test_parquet_extension(self):
|
||||
assert infer_format_from_uri("/data/table.parquet") == "parquet"
|
||||
|
||||
def test_xlsx_extension(self):
|
||||
assert infer_format_from_uri("/data/spreadsheet.xlsx") == "xlsx"
|
||||
|
||||
def test_xls_extension_returns_xls_label(self):
|
||||
# Legacy .xls is mapped so callers can produce a helpful error.
|
||||
assert infer_format_from_uri("/data/old_spreadsheet.xls") == "xls"
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert infer_format_from_uri("/data/FILE.JSON") == "json"
|
||||
assert infer_format_from_uri("/data/FILE.CSV") == "csv"
|
||||
|
||||
def test_unicode_filename(self):
|
||||
assert infer_format_from_uri("/home/user/\u30c7\u30fc\u30bf.json") == "json"
|
||||
assert infer_format_from_uri("/home/user/\u00e9t\u00e9.csv") == "csv"
|
||||
|
||||
def test_unknown_extension(self):
|
||||
assert infer_format_from_uri("/home/user/readme.txt") is None
|
||||
|
||||
def test_no_extension(self):
|
||||
assert infer_format_from_uri("workspace://abc123") is None
|
||||
|
||||
# --- MIME-based ---
|
||||
|
||||
def test_mime_json(self):
|
||||
assert infer_format_from_uri("workspace://abc123#application/json") == "json"
|
||||
|
||||
def test_mime_csv(self):
|
||||
assert infer_format_from_uri("workspace://abc123#text/csv") == "csv"
|
||||
|
||||
def test_mime_tsv(self):
|
||||
assert (
|
||||
infer_format_from_uri("workspace://abc123#text/tab-separated-values")
|
||||
== "tsv"
|
||||
)
|
||||
|
||||
def test_mime_ndjson(self):
|
||||
assert (
|
||||
infer_format_from_uri("workspace://abc123#application/x-ndjson") == "jsonl"
|
||||
)
|
||||
|
||||
def test_mime_yaml(self):
|
||||
assert infer_format_from_uri("workspace://abc123#application/x-yaml") == "yaml"
|
||||
|
||||
def test_mime_xlsx(self):
|
||||
uri = "workspace://abc123#application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
assert infer_format_from_uri(uri) == "xlsx"
|
||||
|
||||
def test_mime_parquet(self):
|
||||
assert (
|
||||
infer_format_from_uri("workspace://abc123#application/vnd.apache.parquet")
|
||||
== "parquet"
|
||||
)
|
||||
|
||||
def test_unknown_mime(self):
|
||||
assert infer_format_from_uri("workspace://abc123#text/plain") is None
|
||||
|
||||
def test_unknown_mime_falls_through_to_extension(self):
|
||||
# Unknown MIME (text/plain) should fall through to extension-based detection.
|
||||
assert infer_format_from_uri("workspace:///data.csv#text/plain") == "csv"
|
||||
|
||||
# --- MIME takes precedence over extension ---
|
||||
|
||||
def test_mime_overrides_extension(self):
|
||||
# .txt extension but JSON MIME → json
|
||||
assert infer_format_from_uri("workspace:///file.txt#application/json") == "json"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — JSON
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseJson:
|
||||
def test_array(self):
|
||||
result = parse_file_content("[1, 2, 3]", "json")
|
||||
assert result == [1, 2, 3]
|
||||
|
||||
def test_object(self):
|
||||
result = parse_file_content('{"key": "value"}', "json")
|
||||
assert result == {"key": "value"}
|
||||
|
||||
def test_nested(self):
|
||||
content = json.dumps({"rows": [[1, 2], [3, 4]]})
|
||||
result = parse_file_content(content, "json")
|
||||
assert result == {"rows": [[1, 2], [3, 4]]}
|
||||
|
||||
def test_scalar_string_stays_as_string(self):
|
||||
result = parse_file_content('"hello"', "json")
|
||||
assert result == '"hello"' # original content, not parsed
|
||||
|
||||
def test_scalar_number_stays_as_string(self):
|
||||
result = parse_file_content("42", "json")
|
||||
assert result == "42"
|
||||
|
||||
def test_scalar_boolean_stays_as_string(self):
|
||||
result = parse_file_content("true", "json")
|
||||
assert result == "true"
|
||||
|
||||
def test_null_stays_as_string(self):
|
||||
result = parse_file_content("null", "json")
|
||||
assert result == "null"
|
||||
|
||||
def test_invalid_json_fallback(self):
|
||||
content = "not json at all"
|
||||
result = parse_file_content(content, "json")
|
||||
assert result == content
|
||||
|
||||
def test_empty_string_fallback(self):
|
||||
result = parse_file_content("", "json")
|
||||
assert result == ""
|
||||
|
||||
def test_bytes_input_decoded(self):
|
||||
result = parse_file_content(b"[1, 2, 3]", "json")
|
||||
assert result == [1, 2, 3]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — JSONL
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseJsonl:
|
||||
def test_tabular_uniform_dicts_to_table_format(self):
|
||||
"""JSONL with uniform dict keys → table format (header + rows),
|
||||
consistent with CSV/TSV/Parquet/Excel output."""
|
||||
content = '{"name":"apple","color":"red"}\n{"name":"banana","color":"yellow"}\n{"name":"cherry","color":"red"}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [
|
||||
["name", "color"],
|
||||
["apple", "red"],
|
||||
["banana", "yellow"],
|
||||
["cherry", "red"],
|
||||
]
|
||||
|
||||
def test_tabular_single_key_dicts(self):
|
||||
"""JSONL with single-key uniform dicts → table format."""
|
||||
content = '{"a": 1}\n{"a": 2}\n{"a": 3}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [["a"], [1], [2], [3]]
|
||||
|
||||
def test_tabular_blank_lines_skipped(self):
|
||||
content = '{"a": 1}\n\n{"a": 2}\n'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [["a"], [1], [2]]
|
||||
|
||||
def test_heterogeneous_dicts_stay_as_list(self):
|
||||
"""JSONL with different keys across objects → list of dicts (no table)."""
|
||||
content = '{"name":"apple"}\n{"color":"red"}\n{"size":3}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [{"name": "apple"}, {"color": "red"}, {"size": 3}]
|
||||
|
||||
def test_partially_overlapping_keys_stay_as_list(self):
|
||||
"""JSONL dicts with partially overlapping keys → list of dicts."""
|
||||
content = '{"name":"apple","color":"red"}\n{"name":"banana","size":"medium"}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [
|
||||
{"name": "apple", "color": "red"},
|
||||
{"name": "banana", "size": "medium"},
|
||||
]
|
||||
|
||||
def test_mixed_types_stay_as_list(self):
|
||||
"""JSONL with non-dict lines → list of parsed values (no table)."""
|
||||
content = '1\n"hello"\n[1,2]\n'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [1, "hello", [1, 2]]
|
||||
|
||||
def test_mixed_dicts_and_non_dicts_stay_as_list(self):
|
||||
"""JSONL mixing dicts and non-dicts → list of parsed values."""
|
||||
content = '{"a": 1}\n42\n{"b": 2}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [{"a": 1}, 42, {"b": 2}]
|
||||
|
||||
def test_tabular_preserves_key_order(self):
|
||||
"""Table header should follow the key order of the first object."""
|
||||
content = '{"z": 1, "a": 2}\n{"z": 3, "a": 4}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result[0] == ["z", "a"] # order from first object
|
||||
assert result[1] == [1, 2]
|
||||
assert result[2] == [3, 4]
|
||||
|
||||
def test_single_dict_stays_as_list(self):
|
||||
"""Single-line JSONL with one dict → [dict], NOT a table.
|
||||
Tabular detection requires ≥2 dicts to avoid vacuously true all()."""
|
||||
content = '{"a": 1, "b": 2}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [{"a": 1, "b": 2}]
|
||||
|
||||
def test_tabular_with_none_values(self):
|
||||
"""Uniform keys but some null values → table with None cells."""
|
||||
content = '{"name":"apple","color":"red"}\n{"name":"banana","color":null}'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == [
|
||||
["name", "color"],
|
||||
["apple", "red"],
|
||||
["banana", None],
|
||||
]
|
||||
|
||||
def test_empty_file_fallback(self):
|
||||
result = parse_file_content("", "jsonl")
|
||||
assert result == ""
|
||||
|
||||
def test_all_blank_lines_fallback(self):
|
||||
result = parse_file_content("\n\n\n", "jsonl")
|
||||
assert result == "\n\n\n"
|
||||
|
||||
def test_invalid_line_fallback(self):
|
||||
content = '{"a": 1}\nnot json\n'
|
||||
result = parse_file_content(content, "jsonl")
|
||||
assert result == content # fallback
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — CSV
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseCsv:
|
||||
def test_basic(self):
|
||||
content = "Name,Score\nAlice,90\nBob,85"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
|
||||
|
||||
def test_quoted_fields(self):
|
||||
content = 'Name,Bio\nAlice,"Loves, commas"\nBob,Simple'
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result[1] == ["Alice", "Loves, commas"]
|
||||
|
||||
def test_single_column_fallback(self):
|
||||
# Only 1 column — not tabular enough.
|
||||
content = "Name\nAlice\nBob"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == content
|
||||
|
||||
def test_empty_rows_skipped(self):
|
||||
content = "A,B\n\n1,2\n\n3,4"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["A", "B"], ["1", "2"], ["3", "4"]]
|
||||
|
||||
def test_empty_file_fallback(self):
|
||||
result = parse_file_content("", "csv")
|
||||
assert result == ""
|
||||
|
||||
def test_utf8_bom(self):
|
||||
"""CSV with a UTF-8 BOM should parse correctly (BOM stripped by decode)."""
|
||||
bom = "\ufeff"
|
||||
content = bom + "Name,Score\nAlice,90\nBob,85"
|
||||
result = parse_file_content(content, "csv")
|
||||
# The BOM may be part of the first header cell; ensure rows are still parsed.
|
||||
assert len(result) == 3
|
||||
assert result[1] == ["Alice", "90"]
|
||||
assert result[2] == ["Bob", "85"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — TSV
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseTsv:
|
||||
def test_basic(self):
|
||||
content = "Name\tScore\nAlice\t90\nBob\t85"
|
||||
result = parse_file_content(content, "tsv")
|
||||
assert result == [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
|
||||
|
||||
def test_single_column_fallback(self):
|
||||
content = "Name\nAlice\nBob"
|
||||
result = parse_file_content(content, "tsv")
|
||||
assert result == content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — YAML
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseYaml:
|
||||
def test_list(self):
|
||||
content = "- apple\n- banana\n- cherry"
|
||||
result = parse_file_content(content, "yaml")
|
||||
assert result == ["apple", "banana", "cherry"]
|
||||
|
||||
def test_dict(self):
|
||||
content = "name: Alice\nage: 30"
|
||||
result = parse_file_content(content, "yaml")
|
||||
assert result == {"name": "Alice", "age": 30}
|
||||
|
||||
def test_nested(self):
|
||||
content = "users:\n - name: Alice\n - name: Bob"
|
||||
result = parse_file_content(content, "yaml")
|
||||
assert result == {"users": [{"name": "Alice"}, {"name": "Bob"}]}
|
||||
|
||||
def test_scalar_stays_as_string(self):
|
||||
result = parse_file_content("hello world", "yaml")
|
||||
assert result == "hello world"
|
||||
|
||||
def test_invalid_yaml_fallback(self):
|
||||
content = ":\n :\n invalid: - -"
|
||||
result = parse_file_content(content, "yaml")
|
||||
# Malformed YAML should fall back to the original string, not raise.
|
||||
assert result == content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — TOML
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseToml:
|
||||
def test_basic(self):
|
||||
content = '[server]\nhost = "localhost"\nport = 8080'
|
||||
result = parse_file_content(content, "toml")
|
||||
assert result == {"server": {"host": "localhost", "port": 8080}}
|
||||
|
||||
def test_flat(self):
|
||||
content = 'name = "test"\ncount = 42'
|
||||
result = parse_file_content(content, "toml")
|
||||
assert result == {"name": "test", "count": 42}
|
||||
|
||||
def test_empty_string_returns_empty_dict(self):
|
||||
result = parse_file_content("", "toml")
|
||||
assert result == {}
|
||||
|
||||
def test_invalid_toml_fallback(self):
|
||||
result = parse_file_content("not = [valid toml", "toml")
|
||||
assert result == "not = [valid toml"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — Parquet (binary)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
try:
|
||||
import pyarrow as _pa # noqa: F401 # pyright: ignore[reportMissingImports]
|
||||
|
||||
_has_pyarrow = True
|
||||
except ImportError:
|
||||
_has_pyarrow = False
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _has_pyarrow, reason="pyarrow not installed")
|
||||
class TestParseParquet:
|
||||
@pytest.fixture
|
||||
def parquet_bytes(self) -> bytes:
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"Name": ["Alice", "Bob"], "Score": [90, 85]})
|
||||
buf = io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
return buf.getvalue()
|
||||
|
||||
def test_basic(self, parquet_bytes: bytes):
|
||||
result = parse_file_content(parquet_bytes, "parquet")
|
||||
assert result == [["Name", "Score"], ["Alice", 90], ["Bob", 85]]
|
||||
|
||||
def test_string_input_fallback(self):
|
||||
# Parquet is binary — string input can't be parsed.
|
||||
result = parse_file_content("not parquet", "parquet")
|
||||
assert result == "not parquet"
|
||||
|
||||
def test_invalid_bytes_fallback(self):
|
||||
result = parse_file_content(b"not parquet bytes", "parquet")
|
||||
assert result == b"not parquet bytes"
|
||||
|
||||
def test_empty_bytes_fallback(self):
|
||||
"""Empty binary input should return the empty bytes, not crash."""
|
||||
result = parse_file_content(b"", "parquet")
|
||||
assert result == b""
|
||||
|
||||
def test_nan_replaced_with_none(self):
|
||||
"""NaN values in Parquet must become None for JSON serializability."""
|
||||
import math
|
||||
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"A": [1.0, float("nan"), 3.0], "B": ["x", None, "z"]})
|
||||
buf = io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
result = parse_file_content(buf.getvalue(), "parquet")
|
||||
# Row with NaN in float col → None
|
||||
assert result[2][0] is None # float NaN → None
|
||||
assert result[2][1] is None # str None → None
|
||||
# Ensure no NaN leaks
|
||||
for row in result[1:]:
|
||||
for cell in row:
|
||||
if isinstance(cell, float):
|
||||
assert not math.isnan(cell), f"NaN leaked: {row}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — Excel (binary)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseExcel:
|
||||
@pytest.fixture
|
||||
def xlsx_bytes(self) -> bytes:
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"Name": ["Alice", "Bob"], "Score": [90, 85]})
|
||||
buf = io.BytesIO()
|
||||
df.to_excel(buf, index=False) # type: ignore[arg-type] # BytesIO is a valid target
|
||||
return buf.getvalue()
|
||||
|
||||
def test_basic(self, xlsx_bytes: bytes):
|
||||
result = parse_file_content(xlsx_bytes, "xlsx")
|
||||
assert result == [["Name", "Score"], ["Alice", 90], ["Bob", 85]]
|
||||
|
||||
def test_string_input_fallback(self):
|
||||
result = parse_file_content("not xlsx", "xlsx")
|
||||
assert result == "not xlsx"
|
||||
|
||||
def test_invalid_bytes_fallback(self):
|
||||
result = parse_file_content(b"not xlsx bytes", "xlsx")
|
||||
assert result == b"not xlsx bytes"
|
||||
|
||||
def test_empty_bytes_fallback(self):
|
||||
"""Empty binary input should return the empty bytes, not crash."""
|
||||
result = parse_file_content(b"", "xlsx")
|
||||
assert result == b""
|
||||
|
||||
def test_nan_replaced_with_none(self):
|
||||
"""NaN values in float columns must become None for JSON serializability."""
|
||||
import math
|
||||
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"A": [1.0, float("nan"), 3.0], "B": ["x", "y", None]})
|
||||
buf = io.BytesIO()
|
||||
df.to_excel(buf, index=False) # type: ignore[arg-type]
|
||||
result = parse_file_content(buf.getvalue(), "xlsx")
|
||||
# Row with NaN in float col → None, not float('nan')
|
||||
assert result[2][0] is None # float NaN → None
|
||||
assert result[3][1] is None # str None → None
|
||||
# Ensure no NaN leaks
|
||||
for row in result[1:]: # skip header
|
||||
for cell in row:
|
||||
if isinstance(cell, float):
|
||||
assert not math.isnan(cell), f"NaN leaked: {row}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_file_content — unknown format / fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFallback:
|
||||
def test_unknown_format_returns_content(self):
|
||||
result = parse_file_content("hello world", "xml")
|
||||
assert result == "hello world"
|
||||
|
||||
def test_none_format_returns_content(self):
|
||||
# Shouldn't normally be called with unrecognised format, but must not crash.
|
||||
result = parse_file_content("hello", "unknown_format")
|
||||
assert result == "hello"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BINARY_FORMATS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBinaryFormats:
|
||||
def test_parquet_is_binary(self):
|
||||
assert "parquet" in BINARY_FORMATS
|
||||
|
||||
def test_xlsx_is_binary(self):
|
||||
assert "xlsx" in BINARY_FORMATS
|
||||
|
||||
def test_text_formats_not_binary(self):
|
||||
for fmt in ("json", "jsonl", "csv", "tsv", "yaml", "toml"):
|
||||
assert fmt not in BINARY_FORMATS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MIME mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMimeMapping:
|
||||
def test_application_yaml(self):
|
||||
assert infer_format_from_uri("workspace://abc123#application/yaml") == "yaml"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CSV sniffer fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCsvSnifferFallback:
|
||||
def test_tab_delimited_with_csv_format(self):
|
||||
"""Tab-delimited content parsed as csv should use sniffer fallback."""
|
||||
content = "Name\tScore\nAlice\t90\nBob\t85"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
|
||||
|
||||
def test_sniffer_failure_returns_content(self):
|
||||
"""When sniffer fails, single-column falls back to raw content."""
|
||||
content = "Name\nAlice\nBob"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenpyxlInvalidFile fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOpenpyxlFallback:
|
||||
def test_invalid_xlsx_non_strict(self):
|
||||
"""Invalid xlsx bytes should fall back gracefully in non-strict mode."""
|
||||
result = parse_file_content(b"not xlsx bytes", "xlsx")
|
||||
assert result == b"not xlsx bytes"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Header-only CSV
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHeaderOnlyCsv:
|
||||
def test_header_only_csv_returns_header_row(self):
|
||||
"""CSV with only a header row (no data rows) should return [[header]]."""
|
||||
content = "Name,Score"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["Name", "Score"]]
|
||||
|
||||
def test_header_only_csv_with_trailing_newline(self):
|
||||
content = "Name,Score\n"
|
||||
result = parse_file_content(content, "csv")
|
||||
assert result == [["Name", "Score"]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Binary format + line range (line range ignored for binary formats)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _has_pyarrow, reason="pyarrow not installed")
|
||||
class TestBinaryFormatLineRange:
|
||||
def test_parquet_ignores_line_range(self):
|
||||
"""Binary formats should parse the full file regardless of line range.
|
||||
|
||||
Line ranges are meaningless for binary formats (parquet/xlsx) — the
|
||||
caller (file_ref._expand_bare_ref) passes raw bytes and the parser
|
||||
should return the complete structured data.
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
|
||||
buf = io.BytesIO()
|
||||
df.to_parquet(buf, index=False)
|
||||
# parse_file_content itself doesn't take a line range — this tests
|
||||
# that the full content is parsed even though the bytes could have
|
||||
# been truncated upstream (it's not, by design).
|
||||
result = parse_file_content(buf.getvalue(), "parquet")
|
||||
assert result == [["A", "B"], [1, 4], [2, 5], [3, 6]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Legacy .xls UX
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestXlsFallback:
|
||||
def test_xls_returns_helpful_error_string(self):
|
||||
"""Uploading a .xls file should produce a helpful error, not garbled binary."""
|
||||
result = parse_file_content(b"\xd0\xcf\x11\xe0garbled", "xls")
|
||||
assert isinstance(result, str)
|
||||
assert ".xlsx" in result
|
||||
assert "not supported" in result.lower()
|
||||
|
||||
def test_xls_with_string_content(self):
|
||||
result = parse_file_content("some text", "xls")
|
||||
assert isinstance(result, str)
|
||||
assert ".xlsx" in result
|
||||
@@ -8,12 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.util.file import (
|
||||
is_media_file_ref,
|
||||
parse_data_uri,
|
||||
resolve_media_content,
|
||||
store_media_file,
|
||||
)
|
||||
from backend.util.file import store_media_file
|
||||
from backend.util.type import MediaFileType
|
||||
|
||||
|
||||
@@ -349,162 +344,3 @@ class TestFileCloudIntegration:
|
||||
execution_context=make_test_context(graph_exec_id=graph_exec_id),
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_media_file_ref
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsMediaFileRef:
|
||||
def test_data_uri(self):
|
||||
assert is_media_file_ref("data:image/png;base64,iVBORw0KGg==") is True
|
||||
|
||||
def test_workspace_uri(self):
|
||||
assert is_media_file_ref("workspace://abc123") is True
|
||||
|
||||
def test_workspace_uri_with_mime(self):
|
||||
assert is_media_file_ref("workspace://abc123#image/png") is True
|
||||
|
||||
def test_http_url(self):
|
||||
assert is_media_file_ref("http://example.com/image.png") is True
|
||||
|
||||
def test_https_url(self):
|
||||
assert is_media_file_ref("https://example.com/image.png") is True
|
||||
|
||||
def test_plain_text(self):
|
||||
assert is_media_file_ref("print('hello')") is False
|
||||
|
||||
def test_local_path(self):
|
||||
assert is_media_file_ref("/tmp/file.txt") is False
|
||||
|
||||
def test_empty_string(self):
|
||||
assert is_media_file_ref("") is False
|
||||
|
||||
def test_filename(self):
|
||||
assert is_media_file_ref("image.png") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_data_uri
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseDataUri:
|
||||
def test_valid_png(self):
|
||||
result = parse_data_uri("data:image/png;base64,iVBORw0KGg==")
|
||||
assert result is not None
|
||||
mime, payload = result
|
||||
assert mime == "image/png"
|
||||
assert payload == "iVBORw0KGg=="
|
||||
|
||||
def test_valid_text(self):
|
||||
result = parse_data_uri("data:text/plain;base64,SGVsbG8=")
|
||||
assert result is not None
|
||||
assert result[0] == "text/plain"
|
||||
assert result[1] == "SGVsbG8="
|
||||
|
||||
def test_mime_case_normalized(self):
|
||||
result = parse_data_uri("data:IMAGE/PNG;base64,abc")
|
||||
assert result is not None
|
||||
assert result[0] == "image/png"
|
||||
|
||||
def test_not_data_uri(self):
|
||||
assert parse_data_uri("workspace://abc123") is None
|
||||
|
||||
def test_plain_text(self):
|
||||
assert parse_data_uri("hello world") is None
|
||||
|
||||
def test_missing_base64(self):
|
||||
assert parse_data_uri("data:image/png;utf-8,abc") is None
|
||||
|
||||
def test_empty_payload(self):
|
||||
result = parse_data_uri("data:image/png;base64,")
|
||||
assert result is not None
|
||||
assert result[1] == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_media_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveMediaContent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_plain_text_passthrough(self):
|
||||
"""Plain text content (not a media ref) passes through unchanged."""
|
||||
ctx = make_test_context()
|
||||
result = await resolve_media_content(
|
||||
MediaFileType("print('hello')"),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
assert result == "print('hello')"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_string_passthrough(self):
|
||||
"""Empty string passes through unchanged."""
|
||||
ctx = make_test_context()
|
||||
result = await resolve_media_content(
|
||||
MediaFileType(""),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_media_ref_delegates_to_store(self):
|
||||
"""Media references are resolved via store_media_file."""
|
||||
ctx = make_test_context()
|
||||
with patch(
|
||||
"backend.util.file.store_media_file",
|
||||
new=AsyncMock(return_value=MediaFileType("data:image/png;base64,abc")),
|
||||
) as mock_store:
|
||||
result = await resolve_media_content(
|
||||
MediaFileType("workspace://img123"),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
assert result == "data:image/png;base64,abc"
|
||||
mock_store.assert_called_once_with(
|
||||
MediaFileType("workspace://img123"),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_uri_delegates_to_store(self):
|
||||
"""Data URIs are also resolved via store_media_file."""
|
||||
ctx = make_test_context()
|
||||
data_uri = "data:image/png;base64,iVBORw0KGg=="
|
||||
with patch(
|
||||
"backend.util.file.store_media_file",
|
||||
new=AsyncMock(return_value=MediaFileType(data_uri)),
|
||||
) as mock_store:
|
||||
result = await resolve_media_content(
|
||||
MediaFileType(data_uri),
|
||||
ctx,
|
||||
return_format="for_external_api",
|
||||
)
|
||||
assert result == data_uri
|
||||
mock_store.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_https_url_delegates_to_store(self):
|
||||
"""HTTPS URLs are resolved via store_media_file."""
|
||||
ctx = make_test_context()
|
||||
with patch(
|
||||
"backend.util.file.store_media_file",
|
||||
new=AsyncMock(return_value=MediaFileType("data:image/png;base64,abc")),
|
||||
) as mock_store:
|
||||
result = await resolve_media_content(
|
||||
MediaFileType("https://example.com/image.png"),
|
||||
ctx,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
assert result == "data:image/png;base64,abc"
|
||||
mock_store.assert_called_once_with(
|
||||
MediaFileType("https://example.com/image.png"),
|
||||
ctx,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
89
autogpt_platform/backend/poetry.lock
generated
89
autogpt_platform/backend/poetry.lock
generated
@@ -1360,18 +1360,6 @@ files = [
|
||||
dnspython = ">=2.0.0"
|
||||
idna = ">=2.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "et-xmlfile"
|
||||
version = "2.0.0"
|
||||
description = "An implementation of lxml.xmlfile for the standard library"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "et_xmlfile-2.0.0-py3-none-any.whl", hash = "sha256:7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa"},
|
||||
{file = "et_xmlfile-2.0.0.tar.gz", hash = "sha256:dab3f4764309081ce75662649be815c4c9081e88f0837825f90fd28317d4da54"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "exa-py"
|
||||
version = "1.16.1"
|
||||
@@ -4240,21 +4228,6 @@ datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
|
||||
realtime = ["websockets (>=13,<16)"]
|
||||
voice-helpers = ["numpy (>=2.0.2)", "sounddevice (>=0.5.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "openpyxl"
|
||||
version = "3.1.5"
|
||||
description = "A Python library to read/write Excel 2010 xlsx/xlsm files"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openpyxl-3.1.5-py2.py3-none-any.whl", hash = "sha256:5282c12b107bffeef825f4617dc029afaf41d0ea60823bbb665ef3079dc79de2"},
|
||||
{file = "openpyxl-3.1.5.tar.gz", hash = "sha256:cf0e3cf56142039133628b5acffe8ef0c12bc902d2aadd3e0fe5878dc08d1050"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
et-xmlfile = "*"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-api"
|
||||
version = "1.39.1"
|
||||
@@ -5457,66 +5430,6 @@ files = [
|
||||
{file = "psycopg2_binary-2.9.11-cp39-cp39-win_amd64.whl", hash = "sha256:875039274f8a2361e5207857899706da840768e2a775bf8c65e82f60b197df02"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyarrow"
|
||||
version = "23.0.1"
|
||||
description = "Python library for Apache Arrow"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:3fab8f82571844eb3c460f90a75583801d14ca0cc32b1acc8c361650e006fd56"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:3f91c038b95f71ddfc865f11d5876c42f343b4495535bd262c7b321b0b94507c"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:d0744403adabef53c985a7f8a082b502a368510c40d184df349a0a8754533258"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:c33b5bf406284fd0bba436ed6f6c3ebe8e311722b441d89397c54f871c6863a2"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ddf743e82f69dcd6dbbcb63628895d7161e04e56794ef80550ac6f3315eeb1d5"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e052a211c5ac9848ae15d5ec875ed0943c0221e2fcfe69eee80b604b4e703222"},
|
||||
{file = "pyarrow-23.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:5abde149bb3ce524782d838eb67ac095cd3fd6090eba051130589793f1a7f76d"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:6f0147ee9e0386f519c952cc670eb4a8b05caa594eeffe01af0e25f699e4e9bb"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:0ae6e17c828455b6265d590100c295193f93cc5675eb0af59e49dbd00d2de350"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:fed7020203e9ef273360b9e45be52a2a47d3103caf156a30ace5247ffb51bdbd"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:26d50dee49d741ac0e82185033488d28d35be4d763ae6f321f97d1140eb7a0e9"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:3c30143b17161310f151f4a2bcfe41b5ff744238c1039338779424e38579d701"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:db2190fa79c80a23fdd29fef4b8992893f024ae7c17d2f5f4db7171fa30c2c78"},
|
||||
{file = "pyarrow-23.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:f00f993a8179e0e1c9713bcc0baf6d6c01326a406a9c23495ec1ba9c9ebf2919"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:f4b0dbfa124c0bb161f8b5ebb40f1a680b70279aa0c9901d44a2b5a20806039f"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:7707d2b6673f7de054e2e83d59f9e805939038eebe1763fe811ee8fa5c0cd1a7"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:86ff03fb9f1a320266e0de855dee4b17da6794c595d207f89bba40d16b5c78b9"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:813d99f31275919c383aab17f0f455a04f5a429c261cc411b1e9a8f5e4aaaa05"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bf5842f960cddd2ef757d486041d57c96483efc295a8c4a0e20e704cbbf39c67"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:564baf97c858ecc03ec01a41062e8f4698abc3e6e2acd79c01c2e97880a19730"},
|
||||
{file = "pyarrow-23.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:07deae7783782ac7250989a7b2ecde9b3c343a643f82e8a4df03d93b633006f0"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:6b8fda694640b00e8af3c824f99f789e836720aa8c9379fb435d4c4953a756b8"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:8ff51b1addc469b9444b7c6f3548e19dc931b172ab234e995a60aea9f6e6025f"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:71c5be5cbf1e1cb6169d2a0980850bccb558ddc9b747b6206435313c47c37677"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:9b6f4f17b43bc39d56fec96e53fe89d94bac3eb134137964371b45352d40d0c2"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9fc13fc6c403d1337acab46a2c4346ca6c9dec5780c3c697cf8abfd5e19b6b37"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5c16ed4f53247fa3ffb12a14d236de4213a4415d127fe9cebed33d51671113e2"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:cecfb12ef629cf6be0b1887f9f86463b0dd3dc3195ae6224e74006be4736035a"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:29f7f7419a0e30264ea261fdc0e5fe63ce5a6095003db2945d7cd78df391a7e1"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:33d648dc25b51fd8055c19e4261e813dfc4d2427f068bcecc8b53d01b81b0500"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:cd395abf8f91c673dd3589cadc8cc1ee4e8674fa61b2e923c8dd215d9c7d1f41"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:00be9576d970c31defb5c32eb72ef585bf600ef6d0a82d5eccaae96639cf9d07"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:c2139549494445609f35a5cda4eb94e2c9e4d704ce60a095b342f82460c73a83"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:7044b442f184d84e2351e5084600f0d7343d6117aabcbc1ac78eb1ae11eb4125"},
|
||||
{file = "pyarrow-23.0.1-cp313-cp313t-win_amd64.whl", hash = "sha256:a35581e856a2fafa12f3f54fce4331862b1cfb0bef5758347a858a4aa9d6bae8"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:5df1161da23636a70838099d4aaa65142777185cc0cdba4037a18cee7d8db9ca"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:fa8e51cb04b9f8c9c5ace6bab63af9a1f88d35c0d6cbf53e8c17c098552285e1"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:0b95a3994f015be13c63148fef8832e8a23938128c185ee951c98908a696e0eb"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:4982d71350b1a6e5cfe1af742c53dfb759b11ce14141870d05d9e540d13bc5d1"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c250248f1fe266db627921c89b47b7c06fee0489ad95b04d50353537d74d6886"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5f4763b83c11c16e5f4c15601ba6dfa849e20723b46aa2617cb4bffe8768479f"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314-win_amd64.whl", hash = "sha256:3a4c85ef66c134161987c17b147d6bffdca4566f9a4c1d81a0a01cdf08414ea5"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:17cd28e906c18af486a499422740298c52d7c6795344ea5002a7720b4eadf16d"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-macosx_12_0_x86_64.whl", hash = "sha256:76e823d0e86b4fb5e1cf4a58d293036e678b5a4b03539be933d3b31f9406859f"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:a62e1899e3078bf65943078b3ad2a6ddcacf2373bc06379aac61b1e548a75814"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:df088e8f640c9fae3b1f495b3c64755c4e719091caf250f3a74d095ddf3c836d"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:46718a220d64677c93bc243af1d44b55998255427588e400677d7192671845c7"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:a09f3876e87f48bc2f13583ab551f0379e5dfb83210391e68ace404181a20690"},
|
||||
{file = "pyarrow-23.0.1-cp314-cp314t-win_amd64.whl", hash = "sha256:527e8d899f14bd15b740cd5a54ad56b7f98044955373a17179d5956ddb93d9ce"},
|
||||
{file = "pyarrow-23.0.1.tar.gz", hash = "sha256:b8c5873e33440b2bc2f4a79d2b47017a89c5a24116c055625e6f2ee50523f019"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyasn1"
|
||||
version = "0.6.2"
|
||||
@@ -8969,4 +8882,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "86dab25684dd46e635a33bd33281a926e5626a874ecc048c34389fecf34a87d8"
|
||||
content-hash = "4e4365721cd3b68c58c237353b74adae1c64233fd4446904c335f23eb866fdca"
|
||||
|
||||
@@ -92,8 +92,6 @@ gravitas-md2gdocs = "^0.1.0"
|
||||
posthog = "^7.6.0"
|
||||
fpdf2 = "^2.8.6"
|
||||
langsmith = "^0.7.7"
|
||||
openpyxl = "^3.1.5"
|
||||
pyarrow = "^23.0.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
aiohappyeyeballs = "^2.6.1"
|
||||
|
||||
@@ -3,7 +3,6 @@ import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||
import { ExclamationMarkIcon } from "@phosphor-icons/react";
|
||||
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { useState } from "react";
|
||||
import { ConnectIntegrationTool } from "../../../tools/ConnectIntegrationTool/ConnectIntegrationTool";
|
||||
import { CreateAgentTool } from "../../../tools/CreateAgent/CreateAgent";
|
||||
import { EditAgentTool } from "../../../tools/EditAgent/EditAgent";
|
||||
import {
|
||||
@@ -130,8 +129,6 @@ export function MessagePartRenderer({ part, messageID, partIndex }: Props) {
|
||||
case "tool-search_docs":
|
||||
case "tool-get_doc_page":
|
||||
return <SearchDocsTool key={key} part={part as ToolUIPart} />;
|
||||
case "tool-connect_integration":
|
||||
return <ConnectIntegrationTool key={key} part={part as ToolUIPart} />;
|
||||
case "tool-run_block":
|
||||
case "tool-continue_run_block":
|
||||
return <RunBlockTool key={key} part={part as ToolUIPart} />;
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import type { SetupRequirementsResponse } from "@/app/api/__generated__/models/setupRequirementsResponse";
|
||||
import type { ToolUIPart } from "ai";
|
||||
import { useState } from "react";
|
||||
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
|
||||
import { ContentMessage } from "../../components/ToolAccordion/AccordionContent";
|
||||
import { SetupRequirementsCard } from "../RunBlock/components/SetupRequirementsCard/SetupRequirementsCard";
|
||||
|
||||
type Props = {
|
||||
part: ToolUIPart;
|
||||
};
|
||||
|
||||
function parseJson(raw: unknown): unknown {
|
||||
if (typeof raw === "string") {
|
||||
try {
|
||||
return JSON.parse(raw);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return raw;
|
||||
}
|
||||
|
||||
function parseOutput(raw: unknown): SetupRequirementsResponse | null {
|
||||
const parsed = parseJson(raw);
|
||||
if (parsed && typeof parsed === "object" && "setup_info" in parsed) {
|
||||
return parsed as SetupRequirementsResponse;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function parseError(raw: unknown): string | null {
|
||||
const parsed = parseJson(raw);
|
||||
if (parsed && typeof parsed === "object" && "message" in parsed) {
|
||||
return String((parsed as { message: unknown }).message);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
export function ConnectIntegrationTool({ part }: Props) {
|
||||
// Persist dismissed state here so SetupRequirementsCard remounts don't re-enable Proceed.
|
||||
const [isDismissed, setIsDismissed] = useState(false);
|
||||
|
||||
const isStreaming =
|
||||
part.state === "input-streaming" || part.state === "input-available";
|
||||
const isError = part.state === "output-error";
|
||||
|
||||
const output =
|
||||
part.state === "output-available"
|
||||
? parseOutput((part as { output?: unknown }).output)
|
||||
: null;
|
||||
|
||||
const errorMessage = isError
|
||||
? (parseError((part as { output?: unknown }).output) ??
|
||||
"Failed to connect integration")
|
||||
: null;
|
||||
|
||||
const rawProvider =
|
||||
(part as { input?: { provider?: string } }).input?.provider ?? "";
|
||||
const providerName =
|
||||
output?.setup_info?.agent_name ??
|
||||
// Sanitize LLM-controlled provider slug: trim and cap at 64 chars to
|
||||
// prevent runaway text in the DOM.
|
||||
(rawProvider ? rawProvider.trim().slice(0, 64) : "integration");
|
||||
|
||||
const label = isStreaming
|
||||
? `Connecting ${providerName}…`
|
||||
: isError
|
||||
? `Failed to connect ${providerName}`
|
||||
: output
|
||||
? `Connect ${output.setup_info?.agent_name ?? providerName}`
|
||||
: `Connect ${providerName}`;
|
||||
|
||||
return (
|
||||
<div className="py-2">
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<MorphingTextAnimation
|
||||
text={label}
|
||||
className={isError ? "text-red-500" : undefined}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{isError && errorMessage && (
|
||||
<p className="mt-1 text-sm text-red-500">{errorMessage}</p>
|
||||
)}
|
||||
|
||||
{output && (
|
||||
<div className="mt-2">
|
||||
{isDismissed ? (
|
||||
<ContentMessage>Connected. Continuing…</ContentMessage>
|
||||
) : (
|
||||
<SetupRequirementsCard
|
||||
output={output}
|
||||
credentialsLabel={`${output.setup_info?.agent_name ?? providerName} credentials`}
|
||||
retryInstruction="I've connected my account. Please continue."
|
||||
onComplete={() => setIsDismissed(true)}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -23,16 +23,12 @@ interface Props {
|
||||
/** Override the label shown above the credentials section.
|
||||
* Defaults to "Credentials". */
|
||||
credentialsLabel?: string;
|
||||
/** Called after Proceed is clicked so the parent can persist the dismissed state
|
||||
* across remounts (avoids re-enabling the Proceed button on remount). */
|
||||
onComplete?: () => void;
|
||||
}
|
||||
|
||||
export function SetupRequirementsCard({
|
||||
output,
|
||||
retryInstruction,
|
||||
credentialsLabel,
|
||||
onComplete,
|
||||
}: Props) {
|
||||
const { onSend } = useCopilotChatActions();
|
||||
|
||||
@@ -72,17 +68,13 @@ export function SetupRequirementsCard({
|
||||
return v !== undefined && v !== null && v !== "";
|
||||
});
|
||||
|
||||
if (hasSent) {
|
||||
return <ContentMessage>Connected. Continuing…</ContentMessage>;
|
||||
}
|
||||
|
||||
const canRun =
|
||||
!hasSent &&
|
||||
(!needsCredentials || isAllCredentialsComplete) &&
|
||||
(!needsInputs || isAllInputsComplete);
|
||||
|
||||
function handleRun() {
|
||||
setHasSent(true);
|
||||
onComplete?.();
|
||||
|
||||
const parts: string[] = [];
|
||||
if (needsCredentials) {
|
||||
|
||||
@@ -125,9 +125,9 @@ export function useCredentialsInput({
|
||||
if (hasAttemptedAutoSelect.current) return;
|
||||
hasAttemptedAutoSelect.current = true;
|
||||
|
||||
// Auto-select only when there is exactly one saved credential.
|
||||
// With multiple options the user must choose — regardless of optional/required.
|
||||
if (savedCreds.length > 1) return;
|
||||
// Auto-select if exactly one credential matches.
|
||||
// For optional fields with multiple options, let the user choose.
|
||||
if (isOptional && savedCreds.length > 1) return;
|
||||
|
||||
const cred = savedCreds[0];
|
||||
onSelectCredential({
|
||||
|
||||
Reference in New Issue
Block a user