Compare commits

...

25 Commits

Author SHA1 Message Date
Zamil Majdy
88eaab2baa Merge remote-tracking branch 'origin/dev' into feat/github-cli-copilot 2026-03-17 06:17:03 +07:00
Zamil Majdy
9a41312769 feat(backend/copilot): parse @@agptfile bare refs by file extension (#12392)
The `@@agptfile:` expansion system previously used content-sniffing
(trying
`json.loads` then `csv.Sniffer`) to decide whether to parse file content
as
structured data. This was fragile — a file containing just `"42"` would
be
parsed as an integer, and the heuristics could misfire on ambiguous
content.

This PR replaces content-sniffing with **extension/MIME-based format
detection**.
When the file has a well-known extension (`.json`, `.csv`, etc.) or MIME
type
fragment (`workspace://id#application/json`), the content is parsed
accordingly.
Unknown formats or parse failures always fall back to plain string — no
surprises.

> [!NOTE]
> This PR builds on the `@@agptfile:` file reference protocol introduced
in #12332 and the structured data auto-parsing added in #12390.
>
> **What is `@@agptfile:`?**
> It is a special URI prefix (e.g. `@@agptfile:workspace:///report.csv`)
that the CoPilot SDK expands inline before sending tool arguments to
blocks. This lets the AI reference workspace files by name, and the SDK
automatically reads and injects the file content. See #12332 for the
full design.

### Changes 🏗️

**New utility: `backend/util/file_content_parser.py`**
- `infer_format(uri)` — determines format from file extension or MIME
fragment
- `parse_file_content(content, fmt)` — parses content, never raises
- Supported text formats: JSON, JSONL/NDJSON, CSV, TSV, YAML, TOML
- Supported binary formats: Parquet (via pyarrow), Excel/XLSX (via
openpyxl)
- JSON scalars (strings, numbers, booleans, null) stay as strings — only
  containers (arrays, objects) are promoted
- CSV/TSV require ≥1 row and ≥2 columns to qualify as tabular data
- Added `openpyxl` dependency for Excel reading via pandas
- Case-insensitive MIME fragment matching per RFC 2045
- Shared `PARSE_EXCEPTIONS` constant to avoid duplication between
modules

**Updated `expand_file_refs_in_args` in `file_ref.py`**
- Bare refs now use `infer_format` + `parse_file_content` instead of the
  old `_try_parse_structured` content-sniffing function
- Binary formats (parquet, xlsx) read raw bytes via `read_file_bytes`
- Embedded refs (text around `@@agptfile:`) still produce plain strings
- **Size guards**: Workspace and sandbox file reads now enforce a 10 MB
limit
  (matching the existing local file limit) to prevent OOM on large files

**Updated `blocks/github/commits.py`**
- Consolidated `_create_blob` and `_create_binary_blob` into a single
function
  with an `encoding` parameter

**Updated copilot system prompt**
- Documents the extension-based structured data parsing and supported
formats

**66 new tests** in `file_content_parser_test.py` covering:
- Format inference (extension, MIME, case-insensitive, precedence)
- All 8 format parsers (happy path + edge cases + fallbacks)
- Binary format handling (string input fallback, invalid bytes fallback)
- Unknown format passthrough

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] All 66 file_content_parser_test.py tests pass
  - [x] All 31 file_ref_test.py tests pass
  - [x] All 13 file_ref_integration_test.py tests pass
  - [x] `poetry run format` passes clean (including pyright)
2026-03-16 22:31:21 +00:00
Zamil Majdy
4b0a445635 fix(copilot): remove implicit gh auth setup-git from sandbox creation
Remove the automatic GitHub credential helper configuration that ran on
every E2B sandbox connect/reconnect. This addressed a review concern
about implicitly giving AutoPilot full GitHub access without user
awareness or opt-in.

The bash_exec tool already injects GH_TOKEN/GITHUB_TOKEN per-command
for users who have connected their account via connect_integration,
which is the explicit opt-in path.
2026-03-17 00:36:51 +07:00
Ubbe
048fb06b0a feat(frontend): add "Jump Back In" button to Library page (#12387)
Adds a "Jump Back In" CTA at the top of the Library page to encourage
users to quickly rerun their most recently successful agent.

Closes SECRT-1536

### Changes 🏗️

- New `JumpBackIn` component with `useJumpBackIn` hook at
`library/components/JumpBackIn/`
- Fetches first page of library agents sorted by `updatedAt`
- Finds the first agent with a `COMPLETED` execution in
`recent_executions`
- Shows banner with agent name + "Jump Back In" button linking to
`/library/agents/{id}`
- Returns `null` (hidden) when loading or when no agent with a
successful run exists

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] `pnpm format`, `pnpm lint`, `pnpm types` all pass
- [x] Verified banner is hidden when no successful runs exist (edge
case)
- [x] Verified library page renders correctly with no visual regressions

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-16 21:35:03 +08:00
Zamil Majdy
36312d2c6e fix(backend/copilot): bust cache on OAuth refresh + persist dismissed state
- creds_manager: call _bust_copilot_cache after refresh_if_needed
  persists the refreshed token so the copilot cache doesn't serve
  a stale access token after silent refresh
- ConnectIntegrationTool: lift isDismissed state to parent so
  SetupRequirementsCard remounts don't re-enable the Proceed button;
  onComplete callback propagates the dismissed signal up
2026-03-16 17:10:18 +07:00
Zamil Majdy
d6d3b8d710 fix(copilot): address coderabbitai major issues — scope token description to E2B, guard cache-bust hook
- connect_integration.py: clarify that GH_TOKEN is injected per-command in
  E2B/cloud only; note that bubblewrap isolates network so retry won't work
- creds_manager._bust_copilot_cache: wrap _on_creds_changed in try/except so
  a failing hook doesn't turn successful create/update/delete into a 500
2026-03-16 15:52:40 +07:00
Zamil Majdy
17d8d0bf05 fix(backend/copilot): run gh auth setup-git once on sandbox connect/reconnect
Move git credential helper setup out of bash_exec (where it ran on
every command) and into _setup_e2b so it runs exactly once per sandbox
connect or reconnect. Non-fatal: logged at debug level on failure.
2026-03-16 15:45:18 +07:00
Zamil Majdy
5a2ab65f41 fix(backend/copilot): run gh auth setup-git once per sandbox session
Use grep to skip re-running if the credential helper is already
configured in ~/.gitconfig — only pays the cost on first command.
Agent can still call it manually if GH_TOKEN changes mid-session.
2026-03-16 15:42:54 +07:00
Zamil Majdy
81a318de3e feat(backend/copilot): improve GitHub OAuth UX and git auth
- Dynamic OAuth scopes: connect_integration tool now accepts a `scopes`
  param so the agent can request exactly the access it needs (e.g.
  `["repo", "read:org"]`); GitHub defaults to `["repo"]` so git
  push/pull works out of the box instead of public-data-only
- Lazy git auth: prepend `gh auth setup-git` on every E2B bash_exec
  when GH_TOKEN is present — git HTTPS clone/push/pull now work
  automatically without the agent needing to set this up manually
- Prefer broadest-scoped OAuth2 credential: sort repo-scoped tokens
  first so a stale public-data token is never picked over a full one
- Collapse SetupRequirementsCard to "Connected. Continuing…" after
  Proceed is clicked instead of leaving the full card visible
- Fix credential auto-select: don't silently pick the first token when
  multiple credentials exist — let the user choose via the dropdown
2026-03-16 15:26:14 +07:00
Zamil Majdy
62c8e8634b fix(copilot): patch _manager singleton directly in tests instead of class constructor
The module-level _manager singleton is created at import time, so patching
IntegrationCredentialsManager after import has no effect. Patch the _manager
attribute directly so get_provider_token uses the mock.
2026-03-16 06:32:36 +07:00
Zamil Majdy
b91c959cd9 fix(copilot): address remaining review findings
- creds_manager: fix TOCTOU in delete() — move get_creds_by_id inside the lock
- creds_manager: replace lazy import in _bust_copilot_cache with a
  register_creds_changed_hook() callback so creds_manager has no runtime
  dependency on the copilot module
- integration_creds: register invalidate_user_provider_cache at module import
  via register_creds_changed_hook() — eliminates the circular-import risk
- integration_creds: add module-level _manager singleton (avoids re-instantiating
  IntegrationCredentialsManager on every cache miss)
- integration_creds: document TTLCache asyncio-only thread-safety assumption
- connect_integration: defer GITHUB_OAUTH_IS_CONFIGURED evaluation to runtime
  with an lru_cache'd helper; importing the module no longer triggers Secrets()
- connect_integration: type missing_credentials dict with _CredentialEntry TypedDict
- connect_integration: cap reason field at 500 chars; add maxLength to JSON schema
- bash_exec: use 'user_id is not None' instead of truthy check
- connect_integration_test: add test for unauthenticated caller (requires_auth guard)
- bash_exec_test: add E2B path tests — token injected when user_id set, skipped
  when user_id is None
- ConnectIntegrationTool.tsx: sanitize LLM-controlled providerName fallback
  (trim + slice to 64 chars)
2026-03-16 06:21:44 +07:00
Zamil Majdy
5b95a2a1ef refactor(copilot): strongly type _PROVIDER_INFO with TypedDict
Replace dict[str, Any] with a _ProviderInfo TypedDict for provider
metadata entries, eliminating key/type drift as new providers are added.
2026-03-16 06:04:02 +07:00
Zamil Majdy
9c2a601167 refactor(copilot): simplify cache with cachetools.TTLCache, fix prompt wording
- Replace manual dict+sentinel cache with two TTLCache instances:
  _token_cache (5min TTL) and _null_cache (60s TTL)
- Remove _cache_set helper and _NO_TOKEN sentinel — TTLCache handles
  expiry and LRU eviction natively
- Update tests to use _token_cache/_null_cache directly; add TTL constant test
- Change _E2B_TOOL_NOTES from "GH_TOKEN is set" to "gh is pre-authenticated"
  so the AI doesn't attempt to read the env var directly
2026-03-16 00:16:26 +07:00
Zamil Majdy
b98e37bf23 refactor(copilot): DRY cache-bust helper, fast eviction test, unified JSON parse
Backend:
- Extract _bust_copilot_cache() in creds_manager.py; create/update/delete
  now each call it once instead of repeating the try/except ImportError block
- test_evicts_oldest_when_full: patch _CACHE_MAX_SIZE to 3 to avoid
  allocating 10 000 entries in CI; remove now-unused _CACHE_MAX_SIZE import

Frontend:
- Extract parseJson() helper shared by parseOutput and parseError in
  ConnectIntegrationTool.tsx, eliminating duplicated try/catch logic
2026-03-16 00:01:10 +07:00
Zamil Majdy
fec8924361 fix(copilot): bust token cache on update/delete, tighten except clause
- creds_manager.create/update/delete now all call invalidate_user_provider_cache
  after mutating credentials, so the next bash_exec always picks up the
  current state without waiting for TTL to expire
- Change broad `except Exception` to `except ImportError` in all three methods
  so real bugs inside invalidate_user_provider_cache are not silently swallowed
- delete() reads the provider before deletion so we know which cache key to evict
- Add tests for invalidate_user_provider_cache: removes sentinel/token entry,
  no-op when key absent, only removes the targeted key
2026-03-15 23:57:11 +07:00
Zamil Majdy
712aee7302 fix(copilot): warn on stale OAuth token fallback, document per-process cache
- Log at WARNING (not DEBUG) when OAuth refresh fails and we fall back to
  a potentially stale token, so operators can diagnose repeated auth failures
- Add multi-worker note to module docstring: _token_cache is process-local;
  each replica maintains its own cache (acceptable for current goal, but a
  shared cache would be needed for cross-replica efficiency)
2026-03-15 23:53:10 +07:00
Zamil Majdy
bef292033e fix(copilot): render error state in ConnectIntegrationTool
When part.state is 'output-error', show the error message from the
backend (ErrorResponse.message) in red text below the status line.
Without this, errors from unknown/unsupported providers were silently
discarded, leaving the user without any feedback.
2026-03-15 23:51:09 +07:00
Zamil Majdy
ec6974e3b8 fix(copilot): invalidate null cache on credential creation
When a user connects an integration, IntegrationCredentialsManager.create()
now calls invalidate_user_provider_cache() to remove any stale _NO_TOKEN
sentinel from the TTL cache. Without this, the first retry after connecting
would still return None for up to _NULL_CACHE_TTL (60 s).

The import is done lazily inside create() to avoid a circular import
between integrations.creds_manager and copilot.integration_creds.
2026-03-15 23:50:34 +07:00
Zamil Majdy
2ef5e2fe77 feat(copilot): bounded TTL cache with sentinel for integration creds
- Replace empty-string sentinel with explicit _NO_TOKEN = object() to
  avoid ambiguity with zero-length tokens
- Bound _token_cache to _CACHE_MAX_SIZE=10_000 entries; _cache_set()
  evicts oldest insertion-order entry when full
- Cache "not connected" results with _NULL_CACHE_TTL=60s (vs 300s for
  found tokens) to avoid a DB hit on every E2B bash_exec for users who
  haven't connected yet, while still picking up a new connection quickly
- Add integration_creds_test.py covering all cache paths, sentinel,
  eviction, OAuth2 preferred/fallback, DB exception, and env var injection
2026-03-15 23:46:05 +07:00
Zamil Majdy
0a8c7221ce fix(copilot): address all review findings
- prompting: rename _SDK_TOOL_NOTES → _E2B_TOOL_NOTES; pass it only to
  _get_cloud_sandbox_supplement() via new extra_notes param — local
  (bubblewrap) mode uses --unshare-net so gh CLI cannot reach GitHub
- integration_creds: cache None results with 60 s TTL (_NULL_CACHE_TTL)
  to avoid a DB hit on every E2B bash_exec for users without GitHub creds;
  found tokens still cached for 5 min (_TOKEN_CACHE_TTL)
- connect_integration: add cross-reference comment to PROVIDER_ENV_VARS
- ConnectIntegrationTool: use provider-specific credentialsLabel
  (e.g. "GitHub credentials" instead of "Integration credentials")
2026-03-15 23:35:25 +07:00
Zamil Majdy
840d1de636 refactor(copilot): move token injection to bash_exec + add integration_creds module
- Extract integration token lookup into backend/copilot/integration_creds.py:
  * Generic get_provider_token(user_id, provider) with 5-min TTL cache
  * get_integration_env_vars(user_id) loops over PROVIDER_ENV_VARS registry
  * Adding a new provider only requires a one-line PROVIDER_ENV_VARS entry

- Inject tokens lazily in bash_exec._execute_on_e2b (E2B has internet access;
  bubblewrap uses --unshare-net so gh CLI cannot reach GitHub regardless)

- Remove eager per-turn GH_TOKEN injection from sdk/service.py (wrong layer:
  bubblewrap is network-isolated, E2B injection now done per-command in bash_exec)

- Fix unsafe output.setup_info?.agent_name access in ConnectIntegrationTool

- Add connect_integration_test.py: unknown provider, known provider structure,
  reason in message, session_id propagation, case-insensitive provider slug
2026-03-15 23:29:33 +07:00
Zamil Majdy
ac55ab619b fix(copilot): address coderabbitai nitpicks
- Use `type Props` instead of `interface Props` in ConnectIntegrationTool
- Simplify parseOutput: skip stringify→parse round-trip for objects
- Document why requires_auth=True despite user_id not being used
2026-03-15 23:14:54 +07:00
Zamil Majdy
a8014d1e92 fix(copilot): address sentry — refresh expired OAuth tokens, handle object output in parseOutput
- service.py: use IntegrationCredentialsManager.refresh_if_needed() instead of
  raw IntegrationCredentialsStore so expired GitHub OAuth tokens are refreshed
  before injection; falls back to stale token on refresh failure to avoid
  breaking the turn entirely (lock=False to avoid blocking the turn)
- ConnectIntegrationTool.tsx: parseOutput now handles both string and
  already-parsed object inputs, matching the RunBlock helper pattern
  used elsewhere in the codebase
2026-03-15 23:06:57 +07:00
Zamil Majdy
7de13c7713 fix(copilot): address self-review — GH_TOKEN OAuth preference, unknown provider error, baseline note scope
- service.py: two-pass loop in _get_github_token_for_user() to genuinely
  prefer OAuth2 tokens over API keys; use creds.type discriminator instead
  of isinstance to match codebase style
- connect_integration.py: return ErrorResponse (not SetupRequirementsResponse)
  for unknown providers so the frontend renders a proper error instead of a
  blank broken card; trim "Integration" suffix from agent_name to avoid
  "Connect GitHub Integration" redundancy
- prompting.py: move GitHub CLI / connect_integration guidance from
  _SHARED_TOOL_NOTES (baseline+SDK) to _SDK_TOOL_NOTES (SDK-only) since
  baseline mode has no subprocess, no gh CLI, and no connect_integration tool
- ConnectIntegrationTool.tsx: simplify parseOutput to short-circuit when raw
  is not a string, removing unnecessary JSON.stringify round-trip
2026-03-15 23:00:46 +07:00
Zamil Majdy
9358b525a0 feat(copilot): inject GH_TOKEN and add connect_integration tool for missing GitHub credentials
When the user has connected GitHub, GH_TOKEN is automatically injected
into the Claude Agent SDK subprocess environment so `gh` CLI works
without any manual auth step.

When GitHub is not connected, the copilot can call the new
connect_integration(provider="github") MCP tool, which surfaces the
same credentials setup card used by GitHub blocks — letting the user
connect their account inline without leaving the chat.

- backend: _get_github_token_for_user() fetches the user's GitHub
  credentials (OAuth2 or API key) and injects GH_TOKEN + GITHUB_TOKEN
  into sdk_env before the Claude Agent SDK subprocess starts
- backend: ConnectIntegrationTool MCP tool returns a
  SetupRequirementsResponse for any known provider (github for now)
- backend: prompting.py documents the gh CLI / connect_integration
  flow in _SHARED_TOOL_NOTES so the copilot knows when to call it
- frontend: ConnectIntegrationTool component renders the existing
  SetupRequirementsCard with a tailored retry instruction
- frontend: MessagePartRenderer dispatches tool-connect_integration
  to the new component
2026-03-15 22:55:08 +07:00
36 changed files with 4827 additions and 97 deletions

View File

@@ -11,7 +11,10 @@ from backend.blocks._base import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import parse_data_uri, resolve_media_content
from backend.util.type import MediaFileType
from ._api import get_api
from ._auth import (
@@ -178,7 +181,8 @@ class FileOperation(StrEnum):
class FileOperationInput(TypedDict):
path: str
content: str
# MediaFileType is a str NewType — no runtime breakage for existing callers.
content: MediaFileType
operation: FileOperation
@@ -275,11 +279,11 @@ class GithubMultiFileCommitBlock(Block):
base_tree_sha = commit_data["tree"]["sha"]
# 3. Build tree entries for each file operation (blobs created concurrently)
async def _create_blob(content: str) -> str:
async def _create_blob(content: str, encoding: str = "utf-8") -> str:
blob_url = repo_url + "/git/blobs"
blob_response = await api.post(
blob_url,
json={"content": content, "encoding": "utf-8"},
json={"content": content, "encoding": encoding},
)
return blob_response.json()["sha"]
@@ -301,10 +305,19 @@ class GithubMultiFileCommitBlock(Block):
else:
upsert_files.append((path, file_op.get("content", "")))
# Create all blobs concurrently
# Create all blobs concurrently. Data URIs (from store_media_file)
# are sent as base64 blobs to preserve binary content.
if upsert_files:
async def _make_blob(content: str) -> str:
parsed = parse_data_uri(content)
if parsed is not None:
_, b64_payload = parsed
return await _create_blob(b64_payload, encoding="base64")
return await _create_blob(content)
blob_shas = await asyncio.gather(
*[_create_blob(content) for _, content in upsert_files]
*[_make_blob(content) for _, content in upsert_files]
)
for (path, _), blob_sha in zip(upsert_files, blob_shas):
tree_entries.append(
@@ -358,15 +371,36 @@ class GithubMultiFileCommitBlock(Block):
input_data: Input,
*,
credentials: GithubCredentials,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
try:
# Resolve media references (workspace://, data:, URLs) to data
# URIs so _make_blob can send binary content correctly.
resolved_files: list[FileOperationInput] = []
for file_op in input_data.files:
content = file_op.get("content", "")
operation = FileOperation(file_op.get("operation", "upsert"))
if operation != FileOperation.DELETE:
content = await resolve_media_content(
MediaFileType(content),
execution_context,
return_format="for_external_api",
)
resolved_files.append(
FileOperationInput(
path=file_op["path"],
content=MediaFileType(content),
operation=operation,
)
)
sha, url = await self.multi_file_commit(
credentials,
input_data.repo_url,
input_data.branch,
input_data.commit_message,
input_data.files,
resolved_files,
)
yield "sha", sha
yield "url", url

View File

@@ -8,6 +8,7 @@ from backend.blocks.github.pull_requests import (
GithubMergePullRequestBlock,
prepare_pr_api_url,
)
from backend.data.execution import ExecutionContext
from backend.util.exceptions import BlockExecutionError
# ── prepare_pr_api_url tests ──
@@ -97,7 +98,11 @@ async def test_multi_file_commit_error_path():
"credentials": TEST_CREDENTIALS_INPUT,
}
with pytest.raises(BlockExecutionError, match="ref update failed"):
async for _ in block.execute(input_data, credentials=TEST_CREDENTIALS):
async for _ in block.execute(
input_data,
credentials=TEST_CREDENTIALS,
execution_context=ExecutionContext(),
):
pass

View File

@@ -11,6 +11,8 @@ from contextvars import ContextVar
from typing import TYPE_CHECKING
from backend.copilot.model import ChatSession
from backend.data.db_accessors import workspace_db
from backend.util.workspace import WorkspaceManager
if TYPE_CHECKING:
from e2b import AsyncSandbox
@@ -82,6 +84,17 @@ def resolve_sandbox_path(path: str) -> str:
return normalized
async def get_workspace_manager(user_id: str, session_id: str) -> WorkspaceManager:
"""Create a session-scoped :class:`WorkspaceManager`.
Placed here (rather than in ``tools/workspace_files``) so that modules
like ``sdk/file_ref`` can import it without triggering the heavy
``tools/__init__`` import chain.
"""
workspace = await workspace_db().get_or_create_workspace(user_id)
return WorkspaceManager(user_id, workspace.id, session_id)
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
"""Return True if *path* is within an allowed host-filesystem location.

View File

@@ -0,0 +1,162 @@
"""Integration credential lookup with per-process TTL cache.
Provides token retrieval for connected integrations so that copilot tools
(e.g. bash_exec) can inject auth tokens into the execution environment without
hitting the database on every command.
Cache semantics (handled automatically by TTLCache):
- Token found → cached for _TOKEN_CACHE_TTL (5 min). Avoids repeated DB hits
for users who have credentials and are running many bash commands.
- No credentials found → cached for _NULL_CACHE_TTL (60 s). Avoids a DB hit
on every E2B command for users who haven't connected an account yet, while
still picking up a newly-connected account within one minute.
Both caches are bounded to _CACHE_MAX_SIZE entries; cachetools evicts the
least-recently-used entry when the limit is reached.
Multi-worker note: both caches are in-process only. Each worker/replica
maintains its own independent cache, so a credential fetch may be duplicated
across processes. This is acceptable for the current goal (reduce DB hits per
session per-process), but if cache efficiency across replicas becomes important
a shared cache (e.g. Redis) should be used instead.
"""
import logging
from typing import cast
from cachetools import TTLCache
from backend.data.model import APIKeyCredentials, OAuth2Credentials
from backend.integrations.creds_manager import (
IntegrationCredentialsManager,
register_creds_changed_hook,
)
logger = logging.getLogger(__name__)
# Maps provider slug → env var names to inject when the provider is connected.
# Add new providers here when adding integration support.
# NOTE: keep in sync with connect_integration._PROVIDER_INFO — both registries
# must be updated when adding a new provider.
PROVIDER_ENV_VARS: dict[str, list[str]] = {
"github": ["GH_TOKEN", "GITHUB_TOKEN"],
}
_TOKEN_CACHE_TTL = 300.0 # seconds — for found tokens
_NULL_CACHE_TTL = 60.0 # seconds — for "not connected" results
_CACHE_MAX_SIZE = 10_000
# (user_id, provider) → token string. TTLCache handles expiry + eviction.
# Thread-safety note: TTLCache is NOT thread-safe, but that is acceptable here
# because all callers (get_provider_token, invalidate_user_provider_cache) run
# exclusively on the asyncio event loop. There are no await points between a
# cache read and its corresponding write within any function, so no concurrent
# coroutine can interleave. If ThreadPoolExecutor workers are ever added to
# this path, a threading.RLock should be wrapped around these caches.
_token_cache: TTLCache[tuple[str, str], str] = TTLCache(
maxsize=_CACHE_MAX_SIZE, ttl=_TOKEN_CACHE_TTL
)
# Separate cache for "no credentials" results with a shorter TTL.
_null_cache: TTLCache[tuple[str, str], bool] = TTLCache(
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
)
def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
"""Remove the cached entry for *user_id*/*provider* from both caches.
Call this after storing new credentials so that the next
``get_provider_token()`` call performs a fresh DB lookup instead of
serving a stale TTL-cached result.
"""
key = (user_id, provider)
_token_cache.pop(key, None)
_null_cache.pop(key, None)
# Register this module's cache-bust function with the credentials manager so
# that any create/update/delete operation immediately evicts stale cache
# entries. This avoids a lazy import inside creds_manager and eliminates the
# circular-import risk.
register_creds_changed_hook(invalidate_user_provider_cache)
# Module-level singleton to avoid re-instantiating IntegrationCredentialsManager
# on every cache-miss call to get_provider_token().
_manager = IntegrationCredentialsManager()
async def get_provider_token(user_id: str, provider: str) -> str | None:
"""Return the user's access token for *provider*, or ``None`` if not connected.
OAuth2 tokens are preferred (refreshed if needed); API keys are the fallback.
Found tokens are cached for _TOKEN_CACHE_TTL (5 min). "Not connected" results
are cached for _NULL_CACHE_TTL (60 s) to avoid a DB hit on every bash_exec
command for users who haven't connected yet, while still picking up a
newly-connected account within one minute.
"""
cache_key = (user_id, provider)
if cache_key in _null_cache:
return None
if cached := _token_cache.get(cache_key):
return cached
manager = _manager
try:
creds_list = await manager.store.get_creds_by_provider(user_id, provider)
except Exception:
logger.debug("Failed to fetch %s credentials for user %s", provider, user_id)
return None
# Pass 1: prefer OAuth2 (carry scope info, refreshable via token endpoint).
# Sort so broader-scoped tokens come first: a token with "repo" scope covers
# full git access, while a public-data-only token lacks push/pull permission.
# lock=False — background injection; not worth a distributed lock acquisition.
oauth2_creds = sorted(
[c for c in creds_list if c.type == "oauth2"],
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
)
for creds in oauth2_creds:
if creds.type == "oauth2":
try:
fresh = await manager.refresh_if_needed(
user_id, cast(OAuth2Credentials, creds), lock=False
)
token = fresh.access_token.get_secret_value()
except Exception:
logger.warning(
"Failed to refresh %s OAuth token for user %s; "
"falling back to potentially stale token",
provider,
user_id,
)
token = cast(OAuth2Credentials, creds).access_token.get_secret_value()
_token_cache[cache_key] = token
return token
# Pass 2: fall back to API key (no expiry, no refresh needed).
for creds in creds_list:
if creds.type == "api_key":
token = cast(APIKeyCredentials, creds).api_key.get_secret_value()
_token_cache[cache_key] = token
return token
# No credentials found — cache to avoid repeated DB hits.
_null_cache[cache_key] = True
return None
async def get_integration_env_vars(user_id: str) -> dict[str, str]:
"""Return env vars for all providers the user has connected.
Iterates :data:`PROVIDER_ENV_VARS`, fetches each token, and builds a flat
``{env_var: token}`` dict ready to pass to a subprocess or E2B sandbox.
Only providers with a stored credential contribute entries.
"""
env: dict[str, str] = {}
for provider, var_names in PROVIDER_ENV_VARS.items():
token = await get_provider_token(user_id, provider)
if token:
for var in var_names:
env[var] = token
return env

View File

@@ -0,0 +1,193 @@
"""Tests for integration_creds — TTL cache and token lookup paths."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
from backend.copilot.integration_creds import (
_NULL_CACHE_TTL,
_TOKEN_CACHE_TTL,
PROVIDER_ENV_VARS,
_null_cache,
_token_cache,
get_integration_env_vars,
get_provider_token,
invalidate_user_provider_cache,
)
from backend.data.model import APIKeyCredentials, OAuth2Credentials
_USER = "user-integration-creds-test"
_PROVIDER = "github"
def _make_api_key_creds(key: str = "test-api-key") -> APIKeyCredentials:
return APIKeyCredentials(
id="creds-api-key",
provider=_PROVIDER,
api_key=SecretStr(key),
title="Test API Key",
expires_at=None,
)
def _make_oauth2_creds(token: str = "test-oauth-token") -> OAuth2Credentials:
return OAuth2Credentials(
id="creds-oauth2",
provider=_PROVIDER,
title="Test OAuth",
access_token=SecretStr(token),
refresh_token=SecretStr("test-refresh"),
access_token_expires_at=None,
refresh_token_expires_at=None,
scopes=[],
)
@pytest.fixture(autouse=True)
def clear_caches():
"""Ensure clean caches before and after every test."""
_token_cache.clear()
_null_cache.clear()
yield
_token_cache.clear()
_null_cache.clear()
class TestInvalidateUserProviderCache:
def test_removes_token_entry(self):
key = (_USER, _PROVIDER)
_token_cache[key] = "tok"
invalidate_user_provider_cache(_USER, _PROVIDER)
assert key not in _token_cache
def test_removes_null_entry(self):
key = (_USER, _PROVIDER)
_null_cache[key] = True
invalidate_user_provider_cache(_USER, _PROVIDER)
assert key not in _null_cache
def test_noop_when_key_not_cached(self):
# Should not raise even when there is no cache entry.
invalidate_user_provider_cache("no-such-user", _PROVIDER)
def test_only_removes_targeted_key(self):
other_key = ("other-user", _PROVIDER)
_token_cache[other_key] = "other-tok"
invalidate_user_provider_cache(_USER, _PROVIDER)
assert other_key in _token_cache
class TestGetProviderToken:
@pytest.mark.asyncio(loop_scope="session")
async def test_returns_cached_token_without_db_hit(self):
_token_cache[(_USER, _PROVIDER)] = "cached-tok"
mock_manager = MagicMock()
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result == "cached-tok"
mock_manager.store.get_creds_by_provider.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
async def test_returns_none_for_null_cached_provider(self):
_null_cache[(_USER, _PROVIDER)] = True
mock_manager = MagicMock()
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result is None
mock_manager.store.get_creds_by_provider.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
async def test_api_key_creds_returned_and_cached(self):
api_creds = _make_api_key_creds("my-api-key")
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[api_creds])
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result == "my-api-key"
assert _token_cache.get((_USER, _PROVIDER)) == "my-api-key"
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth2_preferred_over_api_key(self):
oauth_creds = _make_oauth2_creds("oauth-tok")
api_creds = _make_api_key_creds("api-tok")
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(
return_value=[api_creds, oauth_creds]
)
mock_manager.refresh_if_needed = AsyncMock(return_value=oauth_creds)
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result == "oauth-tok"
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth2_refresh_failure_falls_back_to_stale_token(self):
oauth_creds = _make_oauth2_creds("stale-oauth-tok")
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
mock_manager.refresh_if_needed = AsyncMock(side_effect=RuntimeError("network"))
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result == "stale-oauth-tok"
@pytest.mark.asyncio(loop_scope="session")
async def test_no_credentials_caches_null_entry(self):
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[])
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result is None
assert _null_cache.get((_USER, _PROVIDER)) is True
@pytest.mark.asyncio(loop_scope="session")
async def test_db_exception_returns_none_without_caching(self):
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(
side_effect=RuntimeError("db down")
)
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result is None
# DB errors are not cached — next call will retry
assert (_USER, _PROVIDER) not in _token_cache
assert (_USER, _PROVIDER) not in _null_cache
@pytest.mark.asyncio(loop_scope="session")
async def test_null_cache_has_shorter_ttl_than_token_cache(self):
"""Verify the TTL constants are set correctly for each cache."""
assert _null_cache.ttl == _NULL_CACHE_TTL
assert _token_cache.ttl == _TOKEN_CACHE_TTL
assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL
class TestGetIntegrationEnvVars:
@pytest.mark.asyncio(loop_scope="session")
async def test_injects_all_env_vars_for_provider(self):
_token_cache[(_USER, "github")] = "gh-tok"
result = await get_integration_env_vars(_USER)
for var in PROVIDER_ENV_VARS["github"]:
assert result[var] == "gh-tok"
@pytest.mark.asyncio(loop_scope="session")
async def test_empty_dict_when_no_credentials(self):
_null_cache[(_USER, "github")] = True
result = await get_integration_env_vars(_USER)
assert result == {}

View File

@@ -52,17 +52,68 @@ Examples:
You can embed a reference inside any string argument, or use it as the entire
value. Multiple references in one argument are all expanded.
**Type coercion**: The platform automatically coerces expanded string values
to match the block's expected input types. For example, if a block expects
`list[list[str]]` and you pass a string containing a JSON array (e.g. from
an @@agptfile: expansion), the string will be parsed into the correct type.
**Structured data**: When the **entire** argument value is a single file
reference (no surrounding text), the platform automatically parses the file
content based on its extension or MIME type. Supported formats: JSON, JSONL,
CSV, TSV, YAML, TOML, Parquet, and Excel (.xlsx — first sheet only).
For example, pass `@@agptfile:workspace://<id>` where the file is a `.csv` and
the rows will be parsed into `list[list[str]]` automatically. If the format is
unrecognised or parsing fails, the content is returned as a plain string.
Legacy `.xls` files are **not** supported — only the modern `.xlsx` format.
**Type coercion**: The platform also coerces expanded values to match the
block's expected input types. For example, if a block expects `list[list[str]]`
and the expanded value is a JSON string, it will be parsed into the correct type.
### Media file inputs (format: "file")
Some block inputs accept media files — their schema shows `"format": "file"`.
These fields accept:
- **`workspace://<file_id>`** or **`workspace://<file_id>#<mime>`** — preferred
for large files (images, videos, PDFs). The platform passes the reference
directly to the block without reading the content into memory.
- **`data:<mime>;base64,<payload>`** — inline base64 data URI, suitable for
small files only.
When a block input has `format: "file"`, **pass the `workspace://` URI
directly as the value** (do NOT wrap it in `@@agptfile:`). This avoids large
payloads in tool arguments and preserves binary content (images, videos)
that would be corrupted by text encoding.
Example — committing an image file to GitHub:
```json
{
"files": [{
"path": "docs/hero.png",
"content": "workspace://abc123#image/png",
"operation": "upsert"
}]
}
```
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
All tasks must run in the foreground.
"""
# E2B-only notes — E2B has full internet access so gh CLI works there.
# Not shown in local (bubblewrap) mode: --unshare-net blocks all network.
_E2B_TOOL_NOTES = """
### GitHub CLI (`gh`) and git
- If the user has connected their GitHub account, both `gh` and `git` are
pre-authenticated — use them directly without any manual login step.
`git` HTTPS operations (clone, push, pull) work automatically.
- If the token changes mid-session (e.g. user reconnects with a new token),
run `gh auth setup-git` to re-register the credential helper.
- If `gh` or `git` fails with an authentication error (e.g. "authentication
required", "could not read Username", or exit code 128), call
`connect_integration(provider="github")` to surface the GitHub credentials
setup card so the user can connect their account. Once connected, retry
the operation.
- For operations that need broader access (e.g. private org repos, GitHub
Actions), pass the required scopes: e.g.
`connect_integration(provider="github", scopes=["repo", "read:org"])`.
"""
# Environment-specific supplement templates
def _build_storage_supplement(
@@ -73,6 +124,7 @@ def _build_storage_supplement(
storage_system_1_persistence: list[str],
file_move_name_1_to_2: str,
file_move_name_2_to_1: str,
extra_notes: str = "",
) -> str:
"""Build storage/filesystem supplement for a specific environment.
@@ -87,6 +139,7 @@ def _build_storage_supplement(
storage_system_1_persistence: List of persistence behavior descriptions
file_move_name_1_to_2: Direction label for primary→persistent
file_move_name_2_to_1: Direction label for persistent→primary
extra_notes: Environment-specific notes appended after shared notes
"""
# Format lists as bullet points with proper indentation
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
@@ -120,12 +173,16 @@ def _build_storage_supplement(
### File persistence
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
{_SHARED_TOOL_NOTES}"""
{_SHARED_TOOL_NOTES}{extra_notes}"""
# Pre-built supplements for common environments
def _get_local_storage_supplement(cwd: str) -> str:
"""Local ephemeral storage (files lost between turns)."""
"""Local ephemeral storage (files lost between turns).
Network is isolated (bubblewrap --unshare-net), so internet-dependent CLIs
like gh will not work — no integration env-var notes are included.
"""
return _build_storage_supplement(
working_dir=cwd,
sandbox_type="in a network-isolated sandbox",
@@ -143,7 +200,11 @@ def _get_local_storage_supplement(cwd: str) -> str:
def _get_cloud_sandbox_supplement() -> str:
"""Cloud persistent sandbox (files survive across turns in session)."""
"""Cloud persistent sandbox (files survive across turns in session).
E2B has full internet access, so integration tokens (GH_TOKEN etc.) are
injected per command in bash_exec — include the CLI guidance notes.
"""
return _build_storage_supplement(
working_dir="/home/user",
sandbox_type="in a cloud sandbox with full internet access",
@@ -158,6 +219,7 @@ def _get_cloud_sandbox_supplement() -> str:
],
file_move_name_1_to_2="Sandbox → Persistent",
file_move_name_2_to_1="Persistent → Sandbox",
extra_notes=_E2B_TOOL_NOTES,
)

View File

@@ -3,12 +3,45 @@
This module provides the integration layer between the Claude Agent SDK
and the existing CoPilot tool system, enabling drop-in replacement of
the current LLM orchestration with the battle-tested Claude Agent SDK.
Submodule imports are deferred via PEP 562 ``__getattr__`` to break a
circular import cycle::
sdk/__init__ → tool_adapter → copilot.tools (TOOL_REGISTRY)
copilot.tools → run_block → sdk.file_ref (no cycle here, but…)
sdk/__init__ → service → copilot.prompting → copilot.tools (cycle!)
``tool_adapter`` uses ``TOOL_REGISTRY`` at **module level** to build the
static ``COPILOT_TOOL_NAMES`` list, so the import cannot be deferred to
function scope without a larger refactor (moving tool-name registration
to a separate lightweight module). The lazy-import pattern here is the
least invasive way to break the cycle while keeping module-level constants
intact.
"""
from .service import stream_chat_completion_sdk
from .tool_adapter import create_copilot_mcp_server
from typing import Any
__all__ = [
"stream_chat_completion_sdk",
"create_copilot_mcp_server",
]
# Dispatch table for PEP 562 lazy imports. Each entry is a (module, attr)
# pair so new exports can be added without touching __getattr__ itself.
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
"stream_chat_completion_sdk": (".service", "stream_chat_completion_sdk"),
"create_copilot_mcp_server": (".tool_adapter", "create_copilot_mcp_server"),
}
def __getattr__(name: str) -> Any:
entry = _LAZY_IMPORTS.get(name)
if entry is not None:
module_path, attr = entry
import importlib
module = importlib.import_module(module_path, package=__name__)
value = getattr(module, attr)
globals()[name] = value
return value
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@@ -41,12 +41,20 @@ from typing import Any
from backend.copilot.context import (
get_current_sandbox,
get_sdk_cwd,
get_workspace_manager,
is_allowed_local_path,
resolve_sandbox_path,
)
from backend.copilot.model import ChatSession
from backend.copilot.tools.workspace_files import get_manager
from backend.util.file import parse_workspace_uri
from backend.util.file_content_parser import (
BINARY_FORMATS,
MIME_TO_FORMAT,
PARSE_EXCEPTIONS,
infer_format_from_uri,
parse_file_content,
)
from backend.util.type import MediaFileType
class FileRefExpansionError(Exception):
@@ -74,6 +82,8 @@ _FILE_REF_RE = re.compile(
_MAX_EXPAND_CHARS = 200_000
# Maximum total characters across all @@agptfile: expansions in one string.
_MAX_TOTAL_EXPAND_CHARS = 1_000_000
# Maximum raw byte size for bare ref structured parsing (10 MB).
_MAX_BARE_REF_BYTES = 10_000_000
@dataclass
@@ -83,6 +93,11 @@ class FileRef:
end_line: int | None # 1-indexed, inclusive
# ---------------------------------------------------------------------------
# Public API (top-down: main functions first, helpers below)
# ---------------------------------------------------------------------------
def parse_file_ref(text: str) -> FileRef | None:
"""Return a :class:`FileRef` if *text* is a bare file reference token.
@@ -104,17 +119,6 @@ def parse_file_ref(text: str) -> FileRef | None:
return FileRef(uri=m.group(1), start_line=start, end_line=end)
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
"""Slice *text* to the requested 1-indexed line range (inclusive)."""
if start is None and end is None:
return text
lines = text.splitlines(keepends=True)
s = (start - 1) if start is not None else 0
e = end if end is not None else len(lines)
selected = list(itertools.islice(lines, s, e))
return "".join(selected)
async def read_file_bytes(
uri: str,
user_id: str | None,
@@ -130,27 +134,47 @@ async def read_file_bytes(
if plain.startswith("workspace://"):
if not user_id:
raise ValueError("workspace:// file references require authentication")
manager = await get_manager(user_id, session.session_id)
manager = await get_workspace_manager(user_id, session.session_id)
ws = parse_workspace_uri(plain)
try:
return await (
data = await (
manager.read_file(ws.file_ref)
if ws.is_path
else manager.read_file_by_id(ws.file_ref)
)
except FileNotFoundError:
raise ValueError(f"File not found: {plain}")
except Exception as exc:
except (PermissionError, OSError) as exc:
raise ValueError(f"Failed to read {plain}: {exc}") from exc
except (AttributeError, TypeError, RuntimeError) as exc:
# AttributeError/TypeError: workspace manager returned an
# unexpected type or interface; RuntimeError: async runtime issues.
logger.warning("Unexpected error reading %s: %s", plain, exc)
raise ValueError(f"Failed to read {plain}: {exc}") from exc
# NOTE: Workspace API does not support pre-read size checks;
# the full file is loaded before the size guard below.
if len(data) > _MAX_BARE_REF_BYTES:
raise ValueError(
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
)
return data
if is_allowed_local_path(plain, get_sdk_cwd()):
resolved = os.path.realpath(os.path.expanduser(plain))
try:
# Read with a one-byte overshoot to detect files that exceed the limit
# without a separate os.path.getsize call (avoids TOCTOU race).
with open(resolved, "rb") as fh:
return fh.read()
data = fh.read(_MAX_BARE_REF_BYTES + 1)
if len(data) > _MAX_BARE_REF_BYTES:
raise ValueError(
f"File too large (>{_MAX_BARE_REF_BYTES} bytes, "
f"limit {_MAX_BARE_REF_BYTES})"
)
return data
except FileNotFoundError:
raise ValueError(f"File not found: {plain}")
except Exception as exc:
except OSError as exc:
raise ValueError(f"Failed to read {plain}: {exc}") from exc
sandbox = get_current_sandbox()
@@ -162,9 +186,33 @@ async def read_file_bytes(
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
) from exc
try:
return bytes(await sandbox.files.read(remote, format="bytes"))
except Exception as exc:
data = bytes(await sandbox.files.read(remote, format="bytes"))
except (FileNotFoundError, OSError, UnicodeDecodeError) as exc:
raise ValueError(f"Failed to read from sandbox: {plain}: {exc}") from exc
except Exception as exc:
# E2B SDK raises SandboxException subclasses (NotFoundException,
# TimeoutException, NotEnoughSpaceException, etc.) which don't
# inherit from standard exceptions. Import lazily to avoid a
# hard dependency on e2b at module level.
try:
from e2b.exceptions import SandboxException # noqa: PLC0415
if isinstance(exc, SandboxException):
raise ValueError(
f"Failed to read from sandbox: {plain}: {exc}"
) from exc
except ImportError:
pass
# Re-raise unexpected exceptions (TypeError, AttributeError, etc.)
# so they surface as real bugs rather than being silently masked.
raise
# NOTE: E2B sandbox API does not support pre-read size checks;
# the full file is loaded before the size guard below.
if len(data) > _MAX_BARE_REF_BYTES:
raise ValueError(
f"File too large ({len(data)} bytes, limit {_MAX_BARE_REF_BYTES})"
)
return data
raise ValueError(
f"Path is not allowed (not in workspace, sdk_cwd, or sandbox): {plain}"
@@ -178,15 +226,13 @@ async def resolve_file_ref(
) -> str:
"""Resolve a :class:`FileRef` to its text content."""
raw = await read_file_bytes(ref.uri, user_id, session)
return _apply_line_range(
raw.decode("utf-8", errors="replace"), ref.start_line, ref.end_line
)
return _apply_line_range(_to_str(raw), ref.start_line, ref.end_line)
async def expand_file_refs_in_string(
text: str,
user_id: str | None,
session: "ChatSession",
session: ChatSession,
*,
raise_on_error: bool = False,
) -> str:
@@ -232,6 +278,9 @@ async def expand_file_refs_in_string(
if len(content) > _MAX_EXPAND_CHARS:
content = content[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
remaining = _MAX_TOTAL_EXPAND_CHARS - total_chars
# remaining == 0 means the budget was exactly exhausted by the
# previous ref. The elif below (len > remaining) won't catch
# this since 0 > 0 is false, so we need the <= 0 check.
if remaining <= 0:
content = "[file-ref budget exhausted: total expansion limit reached]"
elif len(content) > remaining:
@@ -252,13 +301,31 @@ async def expand_file_refs_in_string(
async def expand_file_refs_in_args(
args: dict[str, Any],
user_id: str | None,
session: "ChatSession",
session: ChatSession,
*,
input_schema: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Recursively expand ``@@agptfile:...`` references in tool call arguments.
String values are expanded in-place. Nested dicts and lists are
traversed. Non-string scalars are returned unchanged.
**Bare references** (the entire argument value is a single
``@@agptfile:...`` token with no surrounding text) are resolved and then
parsed according to the file's extension or MIME type. See
:mod:`backend.util.file_content_parser` for the full list of supported
formats (JSON, JSONL, CSV, TSV, YAML, TOML, Parquet, Excel).
When *input_schema* is provided and the target property has
``"type": "string"``, structured parsing is skipped — the raw file content
is returned as a plain string so blocks receive the original text.
If the format is unrecognised or parsing fails, the content is returned as
a plain string (the fallback).
**Embedded references** (``@@agptfile:`` mixed with other text) always
produce a plain string — structured parsing only applies to bare refs.
Raises :class:`FileRefExpansionError` if any reference fails to resolve,
so the tool is *not* executed with an error string as its input. The
caller (the MCP tool wrapper) should convert this into an MCP error
@@ -267,15 +334,382 @@ async def expand_file_refs_in_args(
if not args:
return args
async def _expand(value: Any) -> Any:
properties = (input_schema or {}).get("properties", {})
async def _expand(
value: Any,
*,
prop_schema: dict[str, Any] | None = None,
) -> Any:
"""Recursively expand a single argument value.
Strings are checked for ``@@agptfile:`` references and expanded
(bare refs get structured parsing; embedded refs get inline
substitution). Dicts and lists are traversed recursively,
threading the corresponding sub-schema from *prop_schema* so
that nested fields also receive correct type-aware expansion.
Non-string scalars pass through unchanged.
"""
if isinstance(value, str):
ref = parse_file_ref(value)
if ref is not None:
# MediaFileType fields: return the raw URI immediately —
# no file reading, no format inference, no content parsing.
if _is_media_file_field(prop_schema):
return ref.uri
fmt = infer_format_from_uri(ref.uri)
# Workspace URIs by ID (workspace://abc123) have no extension.
# When the MIME fragment is also missing, fall back to the
# workspace file manager's metadata for format detection.
if fmt is None and ref.uri.startswith("workspace://"):
fmt = await _infer_format_from_workspace(ref.uri, user_id, session)
return await _expand_bare_ref(ref, fmt, user_id, session, prop_schema)
# Not a bare ref — do normal inline expansion.
return await expand_file_refs_in_string(
value, user_id, session, raise_on_error=True
)
if isinstance(value, dict):
return {k: await _expand(v) for k, v in value.items()}
# When the schema says this is an object but doesn't define
# inner properties, skip expansion — the caller (e.g.
# RunBlockTool) will expand with the actual nested schema.
if (
prop_schema is not None
and prop_schema.get("type") == "object"
and "properties" not in prop_schema
):
return value
nested_props = (prop_schema or {}).get("properties", {})
return {
k: await _expand(v, prop_schema=nested_props.get(k))
for k, v in value.items()
}
if isinstance(value, list):
return [await _expand(item) for item in value]
items_schema = (prop_schema or {}).get("items")
return [await _expand(item, prop_schema=items_schema) for item in value]
return value
return {k: await _expand(v) for k, v in args.items()}
return {k: await _expand(v, prop_schema=properties.get(k)) for k, v in args.items()}
# ---------------------------------------------------------------------------
# Private helpers (used by the public functions above)
# ---------------------------------------------------------------------------
def _apply_line_range(text: str, start: int | None, end: int | None) -> str:
"""Slice *text* to the requested 1-indexed line range (inclusive).
When the requested range extends beyond the file, a note is appended
so the LLM knows it received the entire remaining content.
"""
if start is None and end is None:
return text
lines = text.splitlines(keepends=True)
total = len(lines)
s = (start - 1) if start is not None else 0
e = end if end is not None else total
selected = list(itertools.islice(lines, s, e))
result = "".join(selected)
if end is not None and end > total:
result += f"\n[Note: file has only {total} lines]\n"
return result
def _to_str(content: str | bytes) -> str:
"""Decode *content* to a string if it is bytes, otherwise return as-is."""
if isinstance(content, str):
return content
return content.decode("utf-8", errors="replace")
def _check_content_size(content: str | bytes) -> None:
"""Raise :class:`ValueError` if *content* exceeds the byte limit.
Raises ``ValueError`` (not ``FileRefExpansionError``) so that the caller
(``_expand_bare_ref``) can unify all resolution errors into a single
``except ValueError`` → ``FileRefExpansionError`` handler, keeping the
error-flow consistent with ``read_file_bytes`` and ``resolve_file_ref``.
For ``bytes``, the length is the byte count directly. For ``str``,
we encode to UTF-8 first because multi-byte characters (e.g. emoji)
mean the byte size can be up to 4x the character count.
"""
if isinstance(content, bytes):
size = len(content)
else:
char_len = len(content)
# Fast lower bound: UTF-8 byte count >= char count.
# If char count already exceeds the limit, reject immediately
# without allocating an encoded copy.
if char_len > _MAX_BARE_REF_BYTES:
size = char_len # real byte size is even larger
# Fast upper bound: each char is at most 4 UTF-8 bytes.
# If worst-case is still under the limit, skip encoding entirely.
elif char_len * 4 <= _MAX_BARE_REF_BYTES:
return
else:
# Edge case: char count is under limit but multibyte chars
# might push byte count over. Encode to get exact size.
size = len(content.encode("utf-8"))
if size > _MAX_BARE_REF_BYTES:
raise ValueError(
f"File too large for structured parsing "
f"({size} bytes, limit {_MAX_BARE_REF_BYTES})"
)
async def _infer_format_from_workspace(
uri: str,
user_id: str | None,
session: ChatSession,
) -> str | None:
"""Look up workspace file metadata to infer the format.
Workspace URIs by ID (``workspace://abc123``) have no file extension.
When the MIME fragment is also absent, we query the workspace file
manager for the file's stored MIME type and original filename.
"""
if not user_id:
return None
try:
ws = parse_workspace_uri(uri)
manager = await get_workspace_manager(user_id, session.session_id)
info = await (
manager.get_file_info(ws.file_ref)
if not ws.is_path
else manager.get_file_info_by_path(ws.file_ref)
)
if info is None:
return None
# Try MIME type first, then filename extension.
mime = (info.mime_type or "").split(";", 1)[0].strip().lower()
return MIME_TO_FORMAT.get(mime) or infer_format_from_uri(info.name)
except (
ValueError,
FileNotFoundError,
OSError,
PermissionError,
AttributeError,
TypeError,
):
# Expected failures: bad URI, missing file, permission denied, or
# workspace manager returning unexpected types. Propagate anything
# else (e.g. programming errors) so they don't get silently swallowed.
logger.debug("workspace metadata lookup failed for %s", uri, exc_info=True)
return None
def _is_media_file_field(prop_schema: dict[str, Any] | None) -> bool:
"""Return True if *prop_schema* describes a MediaFileType field (format: file)."""
if prop_schema is None:
return False
return (
prop_schema.get("type") == "string"
and prop_schema.get("format") == MediaFileType.string_format
)
async def _expand_bare_ref(
ref: FileRef,
fmt: str | None,
user_id: str | None,
session: ChatSession,
prop_schema: dict[str, Any] | None,
) -> Any:
"""Resolve and parse a bare ``@@agptfile:`` reference.
This is the structured-parsing path: the file is read, optionally parsed
according to *fmt*, and adapted to the target *prop_schema*.
Raises :class:`FileRefExpansionError` on resolution or parsing failure.
Note: MediaFileType fields (format: "file") are handled earlier in
``_expand`` to avoid unnecessary format inference and file I/O.
"""
try:
if fmt is not None and fmt in BINARY_FORMATS:
# Binary formats need raw bytes, not UTF-8 text.
# Line ranges are meaningless for binary formats (parquet/xlsx)
# — ignore them and parse full bytes. Warn so the caller/model
# knows the range was silently dropped.
if ref.start_line is not None or ref.end_line is not None:
logger.warning(
"Line range [%s-%s] ignored for binary format %s (%s); "
"binary formats are always parsed in full.",
ref.start_line,
ref.end_line,
fmt,
ref.uri,
)
content: str | bytes = await read_file_bytes(ref.uri, user_id, session)
else:
content = await resolve_file_ref(ref, user_id, session)
except ValueError as exc:
raise FileRefExpansionError(str(exc)) from exc
# For known formats this rejects files >10 MB before parsing.
# For unknown formats _MAX_EXPAND_CHARS (200K chars) below is stricter,
# but this check still guards the parsing path which has no char limit.
# _check_content_size raises ValueError, which we unify here just like
# resolution errors above.
try:
_check_content_size(content)
except ValueError as exc:
raise FileRefExpansionError(str(exc)) from exc
# When the schema declares this parameter as "string",
# return raw file content — don't parse into a structured
# type that would need json.dumps() serialisation.
expect_string = (prop_schema or {}).get("type") == "string"
if expect_string:
if isinstance(content, bytes):
raise FileRefExpansionError(
f"Cannot use {fmt} file as text input: "
f"binary formats (parquet, xlsx) must be passed "
f"to a block that accepts structured data (list/object), "
f"not a string-typed parameter."
)
return content
if fmt is not None:
# Use strict mode for binary formats so we surface the
# actual error (e.g. missing pyarrow/openpyxl, corrupt
# file) instead of silently returning garbled bytes.
strict = fmt in BINARY_FORMATS
try:
parsed = parse_file_content(content, fmt, strict=strict)
except PARSE_EXCEPTIONS as exc:
raise FileRefExpansionError(f"Failed to parse {fmt} file: {exc}") from exc
# Normalize bytes fallback to str so tools never
# receive raw bytes when parsing fails.
if isinstance(parsed, bytes):
parsed = _to_str(parsed)
return _adapt_to_schema(parsed, prop_schema)
# Unknown format — return as plain string, but apply
# the same per-ref character limit used by inline refs
# to prevent injecting unexpectedly large content.
text = _to_str(content)
if len(text) > _MAX_EXPAND_CHARS:
text = text[:_MAX_EXPAND_CHARS] + "\n... [truncated]"
return text
def _adapt_to_schema(parsed: Any, prop_schema: dict[str, Any] | None) -> Any:
"""Adapt a parsed file value to better fit the target schema type.
When the parser returns a natural type (e.g. dict from YAML, list from CSV)
that doesn't match the block's expected type, this function converts it to
a more useful representation instead of relying on pydantic's generic
coercion (which can produce awkward results like flattened dicts → lists).
Returns *parsed* unchanged when no adaptation is needed.
"""
if prop_schema is None:
return parsed
target_type = prop_schema.get("type")
# Dict → array: delegate to helper.
if isinstance(parsed, dict) and target_type == "array":
return _adapt_dict_to_array(parsed, prop_schema)
# List → object: delegate to helper (raises for non-tabular lists).
if isinstance(parsed, list) and target_type == "object":
return _adapt_list_to_object(parsed)
# Tabular list → Any (no type): convert to list of dicts.
# Blocks like FindInDictionaryBlock have `input: Any` which produces
# a schema with no "type" key. Tabular [[header],[rows]] is unusable
# for key lookup, but [{col: val}, ...] works with FindInDict's
# list-of-dicts branch (line 195-199 in data_manipulation.py).
if isinstance(parsed, list) and target_type is None and _is_tabular(parsed):
return _tabular_to_list_of_dicts(parsed)
return parsed
def _adapt_dict_to_array(parsed: dict, prop_schema: dict[str, Any]) -> Any:
"""Adapt a parsed dict to an array-typed field.
Extracts list-valued entries when the target item type is ``array``,
passes through unchanged when item type is ``string`` (lets pydantic error),
or wraps in ``[parsed]`` as a fallback.
"""
items_type = (prop_schema.get("items") or {}).get("type")
if items_type == "array":
# Target is List[List[Any]] — extract list-typed values from the
# dict as inner lists. E.g. YAML {"fruits": [{...},...]}} with
# ConcatenateLists (List[List[Any]]) → [[{...},...]].
list_values = [v for v in parsed.values() if isinstance(v, list)]
if list_values:
return list_values
if items_type == "string":
# Target is List[str] — wrapping a dict would give [dict]
# which can't coerce to strings. Return unchanged and let
# pydantic surface a clear validation error.
return parsed
# Fallback: wrap in a single-element list so the block gets [dict]
# instead of pydantic flattening keys/values into a flat list.
return [parsed]
def _adapt_list_to_object(parsed: list) -> Any:
"""Adapt a parsed list to an object-typed field.
Converts tabular lists to column-dicts; raises for non-tabular lists.
"""
if _is_tabular(parsed):
return _tabular_to_column_dict(parsed)
# Non-tabular list (e.g. a plain Python list from a YAML file) cannot
# be meaningfully coerced to an object. Raise explicitly so callers
# get a clear error rather than pydantic silently wrapping the list.
raise FileRefExpansionError(
"Cannot adapt a non-tabular list to an object-typed field. "
"Expected a tabular structure ([[header], [row1], ...]) or a dict."
)
def _is_tabular(parsed: Any) -> bool:
"""Check if parsed data is in tabular format: [[header], [row1], ...].
Uses isinstance checks because this is a structural type guard on
opaque parser output (Any), not duck typing. A Protocol wouldn't
help here — we need to verify exact list-of-lists shape.
"""
if not isinstance(parsed, list) or len(parsed) < 2:
return False
header = parsed[0]
if not isinstance(header, list) or not header:
return False
if not all(isinstance(h, str) for h in header):
return False
return all(isinstance(row, list) for row in parsed[1:])
def _tabular_to_list_of_dicts(parsed: list) -> list[dict[str, Any]]:
"""Convert [[header], [row1], ...] → [{header[0]: row[0], ...}, ...].
Ragged rows (fewer columns than the header) get None for missing values.
Extra values beyond the header length are silently dropped.
"""
header = parsed[0]
return [
dict(itertools.zip_longest(header, row[: len(header)], fillvalue=None))
for row in parsed[1:]
]
def _tabular_to_column_dict(parsed: list) -> dict[str, list]:
"""Convert [[header], [row1], ...] → {"col1": [val1, ...], ...}.
Ragged rows (fewer columns than the header) get None for missing values,
ensuring all columns have equal length.
"""
header = parsed[0]
return {
col: [row[i] if i < len(row) else None for row in parsed[1:]]
for i, col in enumerate(header)
}

View File

@@ -175,6 +175,199 @@ async def test_expand_args_replaces_file_ref_in_nested_dict():
assert result["count"] == 42
# ---------------------------------------------------------------------------
# expand_file_refs_in_args — bare ref structured parsing
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_bare_ref_json_returns_parsed_dict():
"""Bare ref to a .json file returns parsed dict, not raw string."""
with tempfile.TemporaryDirectory() as sdk_cwd:
json_file = os.path.join(sdk_cwd, "data.json")
with open(json_file, "w") as f:
f.write('{"key": "value", "count": 42}')
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{"data": f"@@agptfile:{json_file}"},
user_id="u1",
session=_make_session(),
)
assert result["data"] == {"key": "value", "count": 42}
@pytest.mark.asyncio
async def test_bare_ref_csv_returns_parsed_table():
"""Bare ref to a .csv file returns list[list[str]] table."""
with tempfile.TemporaryDirectory() as sdk_cwd:
csv_file = os.path.join(sdk_cwd, "data.csv")
with open(csv_file, "w") as f:
f.write("Name,Score\nAlice,90\nBob,85")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{"input": f"@@agptfile:{csv_file}"},
user_id="u1",
session=_make_session(),
)
assert result["input"] == [
["Name", "Score"],
["Alice", "90"],
["Bob", "85"],
]
@pytest.mark.asyncio
async def test_bare_ref_unknown_extension_returns_string():
"""Bare ref to a file with unknown extension returns plain string."""
with tempfile.TemporaryDirectory() as sdk_cwd:
txt_file = os.path.join(sdk_cwd, "readme.txt")
with open(txt_file, "w") as f:
f.write("plain text content")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{"data": f"@@agptfile:{txt_file}"},
user_id="u1",
session=_make_session(),
)
assert result["data"] == "plain text content"
assert isinstance(result["data"], str)
@pytest.mark.asyncio
async def test_bare_ref_invalid_json_falls_back_to_string():
"""Bare ref to a .json file with invalid JSON falls back to string."""
with tempfile.TemporaryDirectory() as sdk_cwd:
json_file = os.path.join(sdk_cwd, "bad.json")
with open(json_file, "w") as f:
f.write("not valid json {{{")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{"data": f"@@agptfile:{json_file}"},
user_id="u1",
session=_make_session(),
)
assert result["data"] == "not valid json {{{"
assert isinstance(result["data"], str)
@pytest.mark.asyncio
async def test_embedded_ref_always_returns_string_even_for_json():
"""Embedded ref (text around it) returns plain string, not parsed JSON."""
with tempfile.TemporaryDirectory() as sdk_cwd:
json_file = os.path.join(sdk_cwd, "data.json")
with open(json_file, "w") as f:
f.write('{"key": "value"}')
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{"data": f"prefix @@agptfile:{json_file} suffix"},
user_id="u1",
session=_make_session(),
)
assert isinstance(result["data"], str)
assert result["data"].startswith("prefix ")
assert result["data"].endswith(" suffix")
@pytest.mark.asyncio
async def test_bare_ref_yaml_returns_parsed_dict():
"""Bare ref to a .yaml file returns parsed dict."""
with tempfile.TemporaryDirectory() as sdk_cwd:
yaml_file = os.path.join(sdk_cwd, "config.yaml")
with open(yaml_file, "w") as f:
f.write("name: test\ncount: 42\n")
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{"config": f"@@agptfile:{yaml_file}"},
user_id="u1",
session=_make_session(),
)
assert result["config"] == {"name": "test", "count": 42}
@pytest.mark.asyncio
async def test_bare_ref_binary_with_line_range_ignores_range():
"""Bare ref to a binary file (.parquet) with line range parses the full file.
Binary formats (parquet, xlsx) ignore line ranges — the full content is
parsed and the range is silently dropped with a log warning.
"""
try:
import pandas as pd
except ImportError:
pytest.skip("pandas not installed")
try:
import pyarrow # noqa: F401 # pyright: ignore[reportMissingImports]
except ImportError:
pytest.skip("pyarrow not installed")
with tempfile.TemporaryDirectory() as sdk_cwd:
parquet_file = os.path.join(sdk_cwd, "data.parquet")
import io as _io
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
buf = _io.BytesIO()
df.to_parquet(buf, index=False)
with open(parquet_file, "wb") as f:
f.write(buf.getvalue())
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
# Line range [1-2] should be silently ignored for binary formats.
result = await expand_file_refs_in_args(
{"data": f"@@agptfile:{parquet_file}[1-2]"},
user_id="u1",
session=_make_session(),
)
# Full file is returned despite the line range.
assert result["data"] == [["A", "B"], [1, 4], [2, 5], [3, 6]]
@pytest.mark.asyncio
async def test_bare_ref_toml_returns_parsed_dict():
"""Bare ref to a .toml file returns parsed dict."""
with tempfile.TemporaryDirectory() as sdk_cwd:
toml_file = os.path.join(sdk_cwd, "config.toml")
with open(toml_file, "w") as f:
f.write('name = "test"\ncount = 42\n')
with patch("backend.copilot.context._current_sdk_cwd") as mock_cwd_var:
mock_cwd_var.get.return_value = sdk_cwd
result = await expand_file_refs_in_args(
{"config": f"@@agptfile:{toml_file}"},
user_id="u1",
session=_make_session(),
)
assert result["config"] == {"name": "test", "count": 42}
# ---------------------------------------------------------------------------
# _read_file_handler — extended to accept workspace:// and local paths
# ---------------------------------------------------------------------------
@@ -219,7 +412,7 @@ async def test_read_file_handler_workspace_uri():
"backend.copilot.sdk.tool_adapter.get_execution_context",
return_value=("user-1", mock_session),
), patch(
"backend.copilot.sdk.file_ref.get_manager",
"backend.copilot.sdk.file_ref.get_workspace_manager",
new=AsyncMock(return_value=mock_manager),
):
result = await _read_file_handler(
@@ -276,7 +469,7 @@ async def test_read_file_bytes_workspace_virtual_path():
mock_manager.read_file.return_value = b"virtual path content"
with patch(
"backend.copilot.sdk.file_ref.get_manager",
"backend.copilot.sdk.file_ref.get_workspace_manager",
new=AsyncMock(return_value=mock_manager),
):
result = await read_file_bytes("workspace:///reports/q1.md", "user-1", session)

File diff suppressed because it is too large Load Diff

View File

@@ -29,6 +29,7 @@ from langfuse import propagate_attributes
from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
from pydantic import BaseModel
from backend.copilot.context import get_workspace_manager
from backend.data.redis_client import get_redis_async
from backend.executor.cluster_lock import AsyncClusterLock
from backend.util.exceptions import NotFoundError
@@ -62,7 +63,6 @@ from ..service import (
)
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tools.workspace_files import get_manager
from ..tracking import track_user_message
from .compaction import CompactionTracker, filter_compaction_messages
from .response_adapter import SDKResponseAdapter
@@ -565,7 +565,7 @@ async def _prepare_file_attachments(
return empty
try:
manager = await get_manager(user_id, session_id)
manager = await get_workspace_manager(user_id, session_id)
except Exception:
logger.warning(
"Failed to create workspace manager for file attachments",
@@ -769,7 +769,7 @@ async def stream_chat_completion_sdk(
)
return None
try:
return await get_or_create_sandbox(
sandbox = await get_or_create_sandbox(
session_id,
api_key=e2b_api_key,
template=config.e2b_sandbox_template,
@@ -783,7 +783,9 @@ async def stream_chat_completion_sdk(
e2b_err,
exc_info=True,
)
return None
return None
return sandbox
async def _fetch_transcript():
"""Download transcript for --resume if applicable."""

View File

@@ -20,7 +20,7 @@ class _FakeFileInfo:
size_bytes: int
_PATCH_TARGET = "backend.copilot.sdk.service.get_manager"
_PATCH_TARGET = "backend.copilot.sdk.service.get_workspace_manager"
class TestPrepareFileAttachments:

View File

@@ -347,7 +347,7 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
:func:`get_sdk_disallowed_tools`.
"""
def _truncating(fn, tool_name: str):
def _truncating(fn, tool_name: str, input_schema: dict[str, Any] | None = None):
"""Wrap a tool handler so its response is truncated to stay under the
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
response adapter before the SDK can apply its own head-truncation.
@@ -361,7 +361,9 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
user_id, session = get_execution_context()
if session is not None:
try:
args = await expand_file_refs_in_args(args, user_id, session)
args = await expand_file_refs_in_args(
args, user_id, session, input_schema=input_schema
)
except FileRefExpansionError as exc:
return _mcp_error(
f"@@agptfile: reference could not be resolved: {exc}. "
@@ -389,11 +391,12 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
for tool_name, base_tool in TOOL_REGISTRY.items():
handler = create_tool_handler(base_tool)
schema = _build_input_schema(base_tool)
decorated = tool(
tool_name,
base_tool.description,
_build_input_schema(base_tool),
)(_truncating(handler, tool_name))
schema,
)(_truncating(handler, tool_name, input_schema=schema))
sdk_tools.append(decorated)
# E2B file tools replace SDK built-in Read/Write/Edit/Glob/Grep.

View File

@@ -12,6 +12,7 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .connect_integration import ConnectIntegrationTool
from .continue_run_block import ContinueRunBlockTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
@@ -84,6 +85,7 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"browser_screenshot": BrowserScreenshotTool(),
# Sandboxed code execution (bubblewrap)
"bash_exec": BashExecTool(),
"connect_integration": ConnectIntegrationTool(),
# Persistent workspace tools (cloud storage, survives across sessions)
# Feature request tools
"search_feature_requests": SearchFeatureRequestsTool(),

View File

@@ -32,6 +32,7 @@ import shutil
import tempfile
from typing import Any
from backend.copilot.context import get_workspace_manager
from backend.copilot.model import ChatSession
from backend.util.request import validate_url_host
@@ -43,7 +44,6 @@ from .models import (
ErrorResponse,
ToolResponseBase,
)
from .workspace_files import get_manager
logger = logging.getLogger(__name__)
@@ -194,7 +194,7 @@ async def _save_browser_state(
),
}
manager = await get_manager(user_id, session.session_id)
manager = await get_workspace_manager(user_id, session.session_id)
await manager.write_file(
content=json.dumps(state).encode("utf-8"),
filename=_STATE_FILENAME,
@@ -218,7 +218,7 @@ async def _restore_browser_state(
Returns True on success (or no state to restore), False on failure.
"""
try:
manager = await get_manager(user_id, session.session_id)
manager = await get_workspace_manager(user_id, session.session_id)
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
if file_info is None:
@@ -360,7 +360,7 @@ async def close_browser_session(session_name: str, user_id: str | None = None) -
# Delete persisted browser state (cookies, localStorage) from workspace.
if user_id:
try:
manager = await get_manager(user_id, session_name)
manager = await get_workspace_manager(user_id, session_name)
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
if file_info is not None:
await manager.delete_file(file_info.id)

View File

@@ -897,7 +897,7 @@ class TestHasLocalSession:
# _save_browser_state
# ---------------------------------------------------------------------------
_GET_MANAGER = "backend.copilot.tools.agent_browser.get_manager"
_GET_MANAGER = "backend.copilot.tools.agent_browser.get_workspace_manager"
def _make_mock_manager():

View File

@@ -22,6 +22,7 @@ from e2b import AsyncSandbox
from e2b.exceptions import TimeoutException
from backend.copilot.context import E2B_WORKDIR, get_current_sandbox
from backend.copilot.integration_creds import get_integration_env_vars
from backend.copilot.model import ChatSession
from .base import BaseTool
@@ -96,7 +97,9 @@ class BashExecTool(BaseTool):
sandbox = get_current_sandbox()
if sandbox is not None:
return await self._execute_on_e2b(sandbox, command, timeout, session_id)
return await self._execute_on_e2b(
sandbox, command, timeout, session_id, user_id
)
# Bubblewrap fallback: local isolated execution.
if not has_full_sandbox():
@@ -133,14 +136,27 @@ class BashExecTool(BaseTool):
command: str,
timeout: int,
session_id: str | None,
user_id: str | None = None,
) -> ToolResponseBase:
"""Execute *command* on the E2B sandbox via commands.run()."""
"""Execute *command* on the E2B sandbox via commands.run().
Integration tokens (e.g. GH_TOKEN) are injected into the sandbox env
for any user with connected accounts. E2B has full internet access, so
CLI tools like ``gh`` work without manual authentication.
"""
envs: dict[str, str] = {
"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin",
}
if user_id is not None:
integration_env = await get_integration_env_vars(user_id)
envs.update(integration_env)
try:
result = await sandbox.commands.run(
f"bash -c {shlex.quote(command)}",
cwd=E2B_WORKDIR,
timeout=timeout,
envs={"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"},
envs=envs,
)
return BashExecResponse(
message=f"Command executed on E2B (exit {result.exit_code})",

View File

@@ -0,0 +1,78 @@
"""Tests for BashExecTool — E2B path with token injection."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from ._test_data import make_session
from .bash_exec import BashExecTool
from .models import BashExecResponse
_USER = "user-bash-exec-test"
def _make_tool() -> BashExecTool:
return BashExecTool()
def _make_sandbox(exit_code: int = 0, stdout: str = "", stderr: str = "") -> MagicMock:
result = MagicMock()
result.exit_code = exit_code
result.stdout = stdout
result.stderr = stderr
sandbox = MagicMock()
sandbox.commands.run = AsyncMock(return_value=result)
return sandbox
class TestBashExecE2BTokenInjection:
@pytest.mark.asyncio(loop_scope="session")
async def test_token_injected_when_user_id_set(self):
"""When user_id is provided, integration env vars are merged into sandbox envs."""
tool = _make_tool()
session = make_session(user_id=_USER)
sandbox = _make_sandbox(stdout="ok")
env_vars = {"GH_TOKEN": "gh-secret", "GITHUB_TOKEN": "gh-secret"}
with patch(
"backend.copilot.tools.bash_exec.get_integration_env_vars",
new=AsyncMock(return_value=env_vars),
) as mock_get_env:
result = await tool._execute_on_e2b(
sandbox=sandbox,
command="echo hi",
timeout=10,
session_id=session.session_id,
user_id=_USER,
)
mock_get_env.assert_awaited_once_with(_USER)
call_kwargs = sandbox.commands.run.call_args[1]
assert call_kwargs["envs"]["GH_TOKEN"] == "gh-secret"
assert call_kwargs["envs"]["GITHUB_TOKEN"] == "gh-secret"
assert isinstance(result, BashExecResponse)
@pytest.mark.asyncio(loop_scope="session")
async def test_no_token_injection_when_user_id_is_none(self):
"""When user_id is None, get_integration_env_vars must NOT be called."""
tool = _make_tool()
session = make_session(user_id=_USER)
sandbox = _make_sandbox(stdout="ok")
with patch(
"backend.copilot.tools.bash_exec.get_integration_env_vars",
new=AsyncMock(return_value={"GH_TOKEN": "should-not-appear"}),
) as mock_get_env:
result = await tool._execute_on_e2b(
sandbox=sandbox,
command="echo hi",
timeout=10,
session_id=session.session_id,
user_id=None,
)
mock_get_env.assert_not_called()
call_kwargs = sandbox.commands.run.call_args[1]
assert "GH_TOKEN" not in call_kwargs["envs"]
assert isinstance(result, BashExecResponse)

View File

@@ -0,0 +1,215 @@
"""Tool for prompting the user to connect a required integration.
When the copilot encounters an authentication failure (e.g. `gh` CLI returns
"authentication required"), it calls this tool to surface the credentials
setup card in the chat — the same UI that appears when a GitHub block runs
without configured credentials.
"""
import functools
from typing import Any, TypedDict
from backend.copilot.model import ChatSession
from backend.copilot.tools.models import (
ErrorResponse,
ResponseType,
SetupInfo,
SetupRequirementsResponse,
ToolResponseBase,
UserReadiness,
)
from .base import BaseTool
class _ProviderInfo(TypedDict):
name: str
types: list[str]
# Default OAuth scopes requested when the agent doesn't specify any.
scopes: list[str]
class _CredentialEntry(TypedDict):
"""Shape of each entry inside SetupRequirementsResponse.user_readiness.missing_credentials."""
id: str
title: str
provider: str
provider_name: str
type: str
types: list[str]
scopes: list[str]
@functools.lru_cache(maxsize=1)
def _is_github_oauth_configured() -> bool:
"""Return True if GitHub OAuth env vars are set.
Evaluated lazily (not at import time) to avoid triggering Secrets() during
module import, which can fail in environments where secrets are not loaded.
"""
from backend.blocks.github._auth import GITHUB_OAUTH_IS_CONFIGURED
return GITHUB_OAUTH_IS_CONFIGURED
# Registry of known providers: name + supported credential types for the UI.
# When adding a new provider, also add its env var names to
# backend.copilot.integration_creds.PROVIDER_ENV_VARS.
def _get_provider_info() -> dict[str, _ProviderInfo]:
"""Build the provider registry, evaluating OAuth config lazily."""
return {
"github": {
"name": "GitHub",
"types": (
["api_key", "oauth2"] if _is_github_oauth_configured() else ["api_key"]
),
# Default: repo scope covers clone/push/pull for public and private repos.
# Agent can request additional scopes (e.g. "read:org") via the scopes param.
"scopes": ["repo"],
},
}
class ConnectIntegrationTool(BaseTool):
"""Surface the credentials setup UI when an integration is not connected."""
@property
def name(self) -> str:
return "connect_integration"
@property
def description(self) -> str:
return (
"Prompt the user to connect a required integration (e.g. GitHub). "
"Call this when an external CLI or API call fails because the user "
"has not connected the relevant account. "
"The tool surfaces a credentials setup card in the chat so the user "
"can authenticate without leaving the page. "
"After the user connects the account, retry the operation. "
"In E2B/cloud sandbox mode the token (GH_TOKEN/GITHUB_TOKEN) is "
"automatically injected per-command in bash_exec — no manual export needed. "
"In local bubblewrap mode network is isolated so GitHub CLI commands "
"will still fail after connecting; inform the user of this limitation."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"provider": {
"type": "string",
"description": (
"Integration provider slug, e.g. 'github'. "
"Must be one of the supported providers."
),
"enum": list(_get_provider_info().keys()),
},
"reason": {
"type": "string",
"description": (
"Brief explanation of why the integration is needed, "
"shown to the user in the setup card."
),
"maxLength": 500,
},
"scopes": {
"type": "array",
"items": {"type": "string"},
"description": (
"OAuth scopes to request. Omit to use the provider default. "
"Add extra scopes when you need more access — e.g. for GitHub: "
"'repo' (clone/push/pull), 'read:org' (org membership), "
"'workflow' (GitHub Actions). "
"Requesting only the scopes you actually need is best practice."
),
},
},
"required": ["provider"],
}
@property
def requires_auth(self) -> bool:
# Require auth so only authenticated users can trigger the setup card.
# The card itself is user-agnostic (no per-user data needed), so
# user_id is intentionally unused in _execute.
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
) -> ToolResponseBase:
del user_id # setup card is user-agnostic; auth is enforced via requires_auth
session_id = session.session_id if session else None
provider: str = (kwargs.get("provider") or "").strip().lower()
reason: str = (kwargs.get("reason") or "").strip()[
:500
] # cap LLM-controlled text
extra_scopes: list[str] = [
str(s).strip() for s in (kwargs.get("scopes") or []) if str(s).strip()
]
provider_info = _get_provider_info()
info = provider_info.get(provider)
if not info:
supported = ", ".join(f"'{p}'" for p in provider_info)
return ErrorResponse(
message=(
f"Unknown provider '{provider}'. "
f"Supported providers: {supported}."
),
error="unknown_provider",
session_id=session_id,
)
provider_name: str = info["name"]
supported_types: list[str] = info["types"]
# Merge agent-requested scopes with provider defaults (deduplicated, order preserved).
default_scopes: list[str] = info["scopes"]
seen: set[str] = set()
scopes: list[str] = []
for s in default_scopes + extra_scopes:
if s not in seen:
seen.add(s)
scopes.append(s)
field_key = f"{provider}_credentials"
message_parts = [
f"To continue, please connect your {provider_name} account.",
]
if reason:
message_parts.append(reason)
credential_entry: _CredentialEntry = {
"id": field_key,
"title": f"{provider_name} Credentials",
"provider": provider,
"provider_name": provider_name,
"type": supported_types[0],
"types": supported_types,
"scopes": scopes,
}
missing_credentials: dict[str, _CredentialEntry] = {field_key: credential_entry}
return SetupRequirementsResponse(
type=ResponseType.SETUP_REQUIREMENTS,
message=" ".join(message_parts),
session_id=session_id,
setup_info=SetupInfo(
agent_id=f"connect_{provider}",
agent_name=provider_name,
user_readiness=UserReadiness(
has_all_credentials=False,
missing_credentials=missing_credentials,
ready_to_run=False,
),
requirements={
"credentials": [missing_credentials[field_key]],
"inputs": [],
"execution_modes": [],
},
),
)

View File

@@ -0,0 +1,135 @@
"""Tests for ConnectIntegrationTool."""
import pytest
from ._test_data import make_session
from .connect_integration import ConnectIntegrationTool
from .models import ErrorResponse, SetupRequirementsResponse
_TEST_USER_ID = "test-user-connect-integration"
class TestConnectIntegrationTool:
def _make_tool(self) -> ConnectIntegrationTool:
return ConnectIntegrationTool()
@pytest.mark.asyncio(loop_scope="session")
async def test_unknown_provider_returns_error(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="nonexistent"
)
assert isinstance(result, ErrorResponse)
assert result.error == "unknown_provider"
assert "nonexistent" in result.message
assert "github" in result.message
@pytest.mark.asyncio(loop_scope="session")
async def test_empty_provider_returns_error(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider=""
)
assert isinstance(result, ErrorResponse)
assert result.error == "unknown_provider"
@pytest.mark.asyncio(loop_scope="session")
async def test_github_provider_returns_setup_response(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="github"
)
assert isinstance(result, SetupRequirementsResponse)
assert result.setup_info.agent_name == "GitHub"
assert result.setup_info.agent_id == "connect_github"
@pytest.mark.asyncio(loop_scope="session")
async def test_github_has_missing_credentials_in_readiness(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="github"
)
assert isinstance(result, SetupRequirementsResponse)
readiness = result.setup_info.user_readiness
assert readiness.has_all_credentials is False
assert readiness.ready_to_run is False
assert "github_credentials" in readiness.missing_credentials
@pytest.mark.asyncio(loop_scope="session")
async def test_github_requirements_include_credential_entry(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="github"
)
assert isinstance(result, SetupRequirementsResponse)
creds = result.setup_info.requirements["credentials"]
assert len(creds) == 1
assert creds[0]["provider"] == "github"
assert creds[0]["id"] == "github_credentials"
@pytest.mark.asyncio(loop_scope="session")
async def test_reason_appears_in_message(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
reason = "Needed to create a pull request."
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="github", reason=reason
)
assert isinstance(result, SetupRequirementsResponse)
assert reason in result.message
@pytest.mark.asyncio(loop_scope="session")
async def test_session_id_propagated(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="github"
)
assert isinstance(result, SetupRequirementsResponse)
assert result.session_id == session.session_id
@pytest.mark.asyncio(loop_scope="session")
async def test_provider_case_insensitive(self):
"""Provider slug is normalised to lowercase before lookup."""
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="GitHub"
)
assert isinstance(result, SetupRequirementsResponse)
def test_tool_name(self):
assert ConnectIntegrationTool().name == "connect_integration"
def test_requires_auth(self):
assert ConnectIntegrationTool().requires_auth is True
@pytest.mark.asyncio(loop_scope="session")
async def test_unauthenticated_user_gets_need_login_response(self):
"""execute() with user_id=None must return NeedLoginResponse, not the setup card.
This verifies that the requires_auth guard in BaseTool.execute() fires
before _execute() is called, so unauthenticated callers cannot probe
which integrations are configured.
"""
import json
tool = self._make_tool()
# Session still needs a user_id string; the None is passed to execute()
# to simulate an unauthenticated call.
session = make_session(user_id=_TEST_USER_ID)
result = await tool.execute(
user_id=None,
session=session,
tool_call_id="test-call-id",
provider="github",
)
raw = result.output
output = json.loads(raw) if isinstance(raw, str) else raw
assert output.get("type") == "need_login"
assert result.success is False

View File

@@ -12,6 +12,7 @@ from backend.copilot.constants import (
COPILOT_SESSION_PREFIX,
)
from backend.copilot.model import ChatSession
from backend.copilot.sdk.file_ref import FileRefExpansionError, expand_file_refs_in_args
from backend.data.db_accessors import review_db
from backend.data.execution import ExecutionContext
@@ -197,6 +198,29 @@ class RunBlockTool(BaseTool):
session_id=session_id,
)
# Expand @@agptfile: refs in input_data with the block's input
# schema. The generic _truncating wrapper skips opaque object
# properties (input_data has no declared inner properties in the
# tool schema), so file ref tokens are still intact here.
# Using the block's schema lets us return raw text for string-typed
# fields and parsed structures for list/dict-typed fields.
if input_data:
try:
input_data = await expand_file_refs_in_args(
input_data,
user_id,
session,
input_schema=input_schema,
)
except FileRefExpansionError as exc:
return ErrorResponse(
message=(
f"Failed to resolve file reference: {exc}. "
"Ensure the file exists before referencing it."
),
session_id=session_id,
)
if missing_credentials:
# Return setup requirements response with missing credentials
credentials_fields_info = block.input_schema.get_credentials_fields_info()

View File

@@ -10,11 +10,11 @@ from pydantic import BaseModel
from backend.copilot.context import (
E2B_WORKDIR,
get_current_sandbox,
get_workspace_manager,
resolve_sandbox_path,
)
from backend.copilot.model import ChatSession
from backend.copilot.tools.sandbox import make_session_path
from backend.data.db_accessors import workspace_db
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
@@ -218,12 +218,6 @@ def _is_text_mime(mime_type: str) -> bool:
return any(mime_type.startswith(t) for t in _TEXT_MIME_PREFIXES)
async def get_manager(user_id: str, session_id: str) -> WorkspaceManager:
"""Create a session-scoped WorkspaceManager."""
workspace = await workspace_db().get_or_create_workspace(user_id)
return WorkspaceManager(user_id, workspace.id, session_id)
async def _resolve_file(
manager: WorkspaceManager,
file_id: str | None,
@@ -386,7 +380,7 @@ class ListWorkspaceFilesTool(BaseTool):
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
try:
manager = await get_manager(user_id, session_id)
manager = await get_workspace_manager(user_id, session_id)
files = await manager.list_files(
path=path_prefix, limit=limit, include_all_sessions=include_all_sessions
)
@@ -536,7 +530,7 @@ class ReadWorkspaceFileTool(BaseTool):
)
try:
manager = await get_manager(user_id, session_id)
manager = await get_workspace_manager(user_id, session_id)
resolved = await _resolve_file(manager, file_id, path, session_id)
if isinstance(resolved, ErrorResponse):
return resolved
@@ -772,7 +766,7 @@ class WriteWorkspaceFileTool(BaseTool):
try:
await scan_content_safe(content, filename=filename)
manager = await get_manager(user_id, session_id)
manager = await get_workspace_manager(user_id, session_id)
rec = await manager.write_file(
content=content,
filename=filename,
@@ -899,7 +893,7 @@ class DeleteWorkspaceFileTool(BaseTool):
)
try:
manager = await get_manager(user_id, session_id)
manager = await get_workspace_manager(user_id, session_id)
resolved = await _resolve_file(manager, file_id, path, session_id)
if isinstance(resolved, ErrorResponse):
return resolved

View File

@@ -25,6 +25,35 @@ logger = logging.getLogger(__name__)
settings = Settings()
_on_creds_changed: Callable[[str, str], None] | None = None
def register_creds_changed_hook(hook: Callable[[str, str], None]) -> None:
"""Register a callback invoked after any credential is created/updated/deleted.
The callback receives ``(user_id, provider)`` and should be idempotent.
Only one hook can be registered at a time; calling this again replaces the
previous hook. Intended to be called once at application startup by the
copilot module to bust its token cache without creating an import cycle.
"""
global _on_creds_changed
_on_creds_changed = hook
def _bust_copilot_cache(user_id: str, provider: str) -> None:
"""Invoke the registered hook (if any) to bust downstream token caches."""
if _on_creds_changed is not None:
try:
_on_creds_changed(user_id, provider)
except Exception:
logger.warning(
"Credential-change hook failed for user=%s provider=%s",
user_id,
provider,
exc_info=True,
)
class IntegrationCredentialsManager:
"""
Handles the lifecycle of integration credentials.
@@ -69,7 +98,11 @@ class IntegrationCredentialsManager:
return self._locks
async def create(self, user_id: str, credentials: Credentials) -> None:
return await self.store.add_creds(user_id, credentials)
result = await self.store.add_creds(user_id, credentials)
# Bust the copilot token cache so that the next bash_exec picks up the
# new credential immediately instead of waiting for _NULL_CACHE_TTL.
_bust_copilot_cache(user_id, credentials.provider)
return result
async def exists(self, user_id: str, credentials_id: str) -> bool:
return (await self.store.get_creds_by_id(user_id, credentials_id)) is not None
@@ -156,6 +189,8 @@ class IntegrationCredentialsManager:
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
await self.store.update_creds(user_id, fresh_credentials)
# Bust copilot cache so the refreshed token is picked up immediately.
_bust_copilot_cache(user_id, fresh_credentials.provider)
if _lock and (await _lock.locked()) and (await _lock.owned()):
try:
await _lock.release()
@@ -168,10 +203,17 @@ class IntegrationCredentialsManager:
async def update(self, user_id: str, updated: Credentials) -> None:
async with self._locked(user_id, updated.id):
await self.store.update_creds(user_id, updated)
# Bust the copilot token cache so the updated credential is picked up immediately.
_bust_copilot_cache(user_id, updated.provider)
async def delete(self, user_id: str, credentials_id: str) -> None:
async with self._locked(user_id, credentials_id):
# Read inside the lock to avoid TOCTOU — another coroutine could
# delete the same credential between the read and the delete.
creds = await self.store.get_creds_by_id(user_id, credentials_id)
await self.store.delete_creds_by_id(user_id, credentials_id)
if creds:
_bust_copilot_cache(user_id, creds.provider)
# -- Locking utilities -- #

View File

@@ -275,13 +275,12 @@ async def store_media_file(
# Process file
elif file.startswith("data:"):
# Data URI
match = re.match(r"^data:([^;]+);base64,(.*)$", file, re.DOTALL)
if not match:
parsed_uri = parse_data_uri(file)
if parsed_uri is None:
raise ValueError(
"Invalid data URI format. Expected data:<mime>;base64,<data>"
)
mime_type = match.group(1).strip().lower()
b64_content = match.group(2).strip()
mime_type, b64_content = parsed_uri
# Generate filename and decode
extension = _extension_from_mime(mime_type)
@@ -415,13 +414,70 @@ def get_dir_size(path: Path) -> int:
return total
async def resolve_media_content(
content: MediaFileType,
execution_context: "ExecutionContext",
*,
return_format: MediaReturnFormat,
) -> MediaFileType:
"""Resolve a ``MediaFileType`` value if it is a media reference, pass through otherwise.
Convenience wrapper around :func:`is_media_file_ref` + :func:`store_media_file`.
Plain text content (source code, filenames) is returned unchanged. Media
references (``data:``, ``workspace://``, ``http(s)://``) are resolved via
:func:`store_media_file` using *return_format*.
Use this when a block field is typed as ``MediaFileType`` but may contain
either literal text or a media reference.
"""
if not content or not is_media_file_ref(content):
return content
return await store_media_file(
content, execution_context, return_format=return_format
)
def is_media_file_ref(value: str) -> bool:
"""Return True if *value* looks like a ``MediaFileType`` reference.
Detects data URIs, workspace:// references, and HTTP(S) URLs — the
formats accepted by :func:`store_media_file`. Plain text content
(e.g. source code, filenames) returns False.
Known limitation: HTTP(S) URL detection is heuristic. Any string that
starts with ``http://`` or ``https://`` is treated as a media URL, even
if it appears as a URL inside source-code comments or documentation.
Blocks that produce source code or Markdown as output may therefore
trigger false positives. Callers that need higher precision should
inspect the string further (e.g. verify the URL is reachable or has a
media-friendly extension).
Note: this does *not* match local file paths, which are ambiguous
(could be filenames or actual paths). Blocks that need to resolve
local paths should check for them separately.
"""
return value.startswith(("data:", "workspace://", "http://", "https://"))
def parse_data_uri(value: str) -> tuple[str, str] | None:
"""Parse a ``data:<mime>;base64,<payload>`` URI.
Returns ``(mime_type, base64_payload)`` if *value* is a valid data URI,
or ``None`` if it is not.
"""
match = re.match(r"^data:([^;]+);base64,(.*)$", value, re.DOTALL)
if not match:
return None
return match.group(1).strip().lower(), match.group(2).strip()
def get_mime_type(file: str) -> str:
"""
Get the MIME type of a file, whether it's a data URI, URL, or local path.
"""
if file.startswith("data:"):
match = re.match(r"^data:([^;]+);base64,", file)
return match.group(1) if match else "application/octet-stream"
parsed_uri = parse_data_uri(file)
return parsed_uri[0] if parsed_uri else "application/octet-stream"
elif file.startswith(("http://", "https://")):
parsed_url = urlparse(file)

View File

@@ -0,0 +1,375 @@
"""Parse file content into structured Python objects based on file format.
Used by the ``@@agptfile:`` expansion system to eagerly parse well-known file
formats into native Python types *before* schema-driven coercion runs. This
lets blocks with ``Any``-typed inputs receive structured data rather than raw
strings, while blocks expecting strings get the value coerced back via
``convert()``.
Supported formats:
- **JSON** (``.json``) — arrays and objects are promoted; scalars stay as strings
- **JSON Lines** (``.jsonl``, ``.ndjson``) — each non-empty line parsed as JSON;
when all lines are dicts with the same keys (tabular data), output is
``list[list[Any]]`` with a header row, consistent with CSV/Parquet/Excel;
otherwise returns a plain ``list`` of parsed values
- **CSV** (``.csv``) — ``csv.reader`` → ``list[list[str]]``
- **TSV** (``.tsv``) — tab-delimited → ``list[list[str]]``
- **YAML** (``.yaml``, ``.yml``) — parsed via PyYAML; containers only
- **TOML** (``.toml``) — parsed via stdlib ``tomllib``
- **Parquet** (``.parquet``) — via pandas/pyarrow → ``list[list[Any]]`` with header row
- **Excel** (``.xlsx``) — via pandas/openpyxl → ``list[list[Any]]`` with header row
(legacy ``.xls`` is **not** supported — only the modern OOXML format)
The **fallback contract** is enforced by :func:`parse_file_content`, not by
individual parser functions. If any parser raises, ``parse_file_content``
catches the exception and returns the original content unchanged (string for
text formats, bytes for binary formats). Callers should never see an
exception from the public API when ``strict=False``.
"""
import csv
import io
import json
import logging
import tomllib
import zipfile
from collections.abc import Callable
# posixpath.splitext handles forward-slash URI paths correctly on all platforms,
# unlike os.path.splitext which uses platform-native separators.
from posixpath import splitext
from typing import Any
import yaml
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Extension / MIME → format label mapping
# ---------------------------------------------------------------------------
_EXT_TO_FORMAT: dict[str, str] = {
".json": "json",
".jsonl": "jsonl",
".ndjson": "jsonl",
".csv": "csv",
".tsv": "tsv",
".yaml": "yaml",
".yml": "yaml",
".toml": "toml",
".parquet": "parquet",
".xlsx": "xlsx",
}
MIME_TO_FORMAT: dict[str, str] = {
"application/json": "json",
"application/x-ndjson": "jsonl",
"application/jsonl": "jsonl",
"text/csv": "csv",
"text/tab-separated-values": "tsv",
"application/x-yaml": "yaml",
"application/yaml": "yaml",
"text/yaml": "yaml",
"application/toml": "toml",
"application/vnd.apache.parquet": "parquet",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx",
}
# Formats that require raw bytes rather than decoded text.
BINARY_FORMATS: frozenset[str] = frozenset({"parquet", "xlsx"})
# ---------------------------------------------------------------------------
# Public API (top-down: main functions first, helpers below)
# ---------------------------------------------------------------------------
def infer_format_from_uri(uri: str) -> str | None:
"""Return a format label based on URI extension or MIME fragment.
Returns ``None`` when the format cannot be determined — the caller should
fall back to returning the content as a plain string.
"""
# 1. Check MIME fragment (workspace://abc123#application/json)
if "#" in uri:
_, fragment = uri.rsplit("#", 1)
fmt = MIME_TO_FORMAT.get(fragment.lower())
if fmt:
return fmt
# 2. Check file extension from the path portion.
# Strip the fragment first so ".json#mime" doesn't confuse splitext.
path = uri.split("#")[0].split("?")[0]
_, ext = splitext(path)
fmt = _EXT_TO_FORMAT.get(ext.lower())
if fmt is not None:
return fmt
# Legacy .xls is not supported — map it so callers can produce a
# user-friendly error instead of returning garbled binary.
if ext.lower() == ".xls":
return "xls"
return None
def parse_file_content(content: str | bytes, fmt: str, *, strict: bool = False) -> Any:
"""Parse *content* according to *fmt* and return a native Python value.
When *strict* is ``False`` (default), returns the original *content*
unchanged if *fmt* is not recognised or parsing fails for any reason.
This mode **never raises**.
When *strict* is ``True``, parsing errors are propagated to the caller.
Unrecognised formats or type mismatches (e.g. text for a binary format)
still return *content* unchanged without raising.
"""
if fmt == "xls":
return (
"[Unsupported format] Legacy .xls files are not supported. "
"Please re-save the file as .xlsx (Excel 2007+) and upload again."
)
try:
if fmt in BINARY_FORMATS:
parser = _BINARY_PARSERS.get(fmt)
if parser is None:
return content
if isinstance(content, str):
# Caller gave us text for a binary format — can't parse.
return content
return parser(content)
parser = _TEXT_PARSERS.get(fmt)
if parser is None:
return content
if isinstance(content, bytes):
content = content.decode("utf-8", errors="replace")
return parser(content)
except PARSE_EXCEPTIONS:
if strict:
raise
logger.debug("Structured parsing failed for format=%s, falling back", fmt)
return content
# ---------------------------------------------------------------------------
# Exception loading helpers
# ---------------------------------------------------------------------------
def _load_openpyxl_exception() -> type[Exception]:
"""Return openpyxl's InvalidFileException, raising ImportError if absent."""
from openpyxl.utils.exceptions import InvalidFileException # noqa: PLC0415
return InvalidFileException
def _load_arrow_exception() -> type[Exception]:
"""Return pyarrow's ArrowException, raising ImportError if absent."""
from pyarrow import ArrowException # noqa: PLC0415
return ArrowException
def _optional_exc(loader: "Callable[[], type[Exception]]") -> "type[Exception] | None":
"""Return the exception class from *loader*, or ``None`` if the dep is absent."""
try:
return loader()
except ImportError:
return None
# Exception types that can be raised during file content parsing.
# Shared between ``parse_file_content`` (which catches them in non-strict mode)
# and ``file_ref._expand_bare_ref`` (which re-raises them as FileRefExpansionError).
#
# Optional-dependency exception types are loaded via a helper that raises
# ``ImportError`` at *parse time* rather than silently becoming ``None`` here.
# This ensures mypy sees clean types and missing deps surface as real errors.
PARSE_EXCEPTIONS: tuple[type[BaseException], ...] = tuple(
exc
for exc in (
json.JSONDecodeError,
csv.Error,
yaml.YAMLError,
tomllib.TOMLDecodeError,
ValueError,
UnicodeDecodeError,
ImportError,
OSError,
KeyError,
TypeError,
zipfile.BadZipFile,
_optional_exc(_load_openpyxl_exception),
# ArrowException covers ArrowIOError and ArrowCapacityError which
# do not inherit from standard exceptions; ArrowInvalid/ArrowTypeError
# already map to ValueError/TypeError but this catches the rest.
_optional_exc(_load_arrow_exception),
)
if exc is not None
)
# ---------------------------------------------------------------------------
# Text-based parsers (content: str → Any)
# ---------------------------------------------------------------------------
def _parse_container(parser: Callable[[str], Any], content: str) -> list | dict | str:
"""Parse *content* and return the result only if it is a container (list/dict).
Scalar values (strings, numbers, booleans, None) are discarded and the
original *content* string is returned instead. This prevents e.g. a JSON
file containing just ``"42"`` from silently becoming an int.
"""
parsed = parser(content)
if isinstance(parsed, (list, dict)):
return parsed
return content
def _parse_json(content: str) -> list | dict | str:
return _parse_container(json.loads, content)
def _parse_jsonl(content: str) -> Any:
lines = [json.loads(line) for line in content.splitlines() if line.strip()]
if not lines:
return content
# When every line is a dict with the same keys, convert to table format
# (header row + data rows) — consistent with CSV/TSV/Parquet/Excel output.
# Require ≥2 dicts so a single-line JSONL stays as [dict] (not a table).
if len(lines) >= 2 and all(isinstance(obj, dict) for obj in lines):
keys = list(lines[0].keys())
# Cache as tuple to avoid O(n×k) list allocations in the all() call.
keys_tuple = tuple(keys)
if keys and all(tuple(obj.keys()) == keys_tuple for obj in lines[1:]):
return [keys] + [[obj[k] for k in keys] for obj in lines]
return lines
def _parse_csv(content: str) -> Any:
return _parse_delimited(content, delimiter=",")
def _parse_tsv(content: str) -> Any:
return _parse_delimited(content, delimiter="\t")
def _parse_delimited(content: str, *, delimiter: str) -> Any:
reader = csv.reader(io.StringIO(content), delimiter=delimiter)
# csv.reader never yields [] — blank lines yield [""]. Filter out
# rows where every cell is empty (i.e. truly blank lines).
rows = [row for row in reader if _row_has_content(row)]
if not rows:
return content
# If the declared delimiter produces only single-column rows, try
# sniffing the actual delimiter — catches misidentified files (e.g.
# a tab-delimited file with a .csv extension).
if len(rows[0]) == 1:
try:
dialect = csv.Sniffer().sniff(content[:8192])
if dialect.delimiter != delimiter:
reader = csv.reader(io.StringIO(content), dialect)
rows = [row for row in reader if _row_has_content(row)]
except csv.Error:
pass
if rows and len(rows[0]) >= 2:
return rows
return content
def _row_has_content(row: list[str]) -> bool:
"""Return True when *row* contains at least one non-empty cell.
``csv.reader`` never yields ``[]`` — truly blank lines yield ``[""]``.
This predicate filters those out consistently across the initial read
and the sniffer-fallback re-read.
"""
return any(cell for cell in row)
def _parse_yaml(content: str) -> list | dict | str:
# NOTE: YAML anchor/alias expansion can amplify input beyond the 10MB cap.
# safe_load prevents code execution; for production hardening consider
# a YAML parser with expansion limits (e.g. ruamel.yaml with max_alias_count).
if "\n---" in content or content.startswith("---\n"):
# Multi-document YAML: only the first document is parsed; the rest
# are silently ignored by yaml.safe_load. Warn so callers are aware.
logger.warning(
"Multi-document YAML detected (--- separator); "
"only the first document will be parsed."
)
return _parse_container(yaml.safe_load, content)
def _parse_toml(content: str) -> Any:
parsed = tomllib.loads(content)
# tomllib.loads always returns a dict — return it even if empty.
return parsed
_TEXT_PARSERS: dict[str, Callable[[str], Any]] = {
"json": _parse_json,
"jsonl": _parse_jsonl,
"csv": _parse_csv,
"tsv": _parse_tsv,
"yaml": _parse_yaml,
"toml": _parse_toml,
}
# ---------------------------------------------------------------------------
# Binary-based parsers (content: bytes → Any)
# ---------------------------------------------------------------------------
def _parse_parquet(content: bytes) -> list[list[Any]]:
import pandas as pd
df = pd.read_parquet(io.BytesIO(content))
return _df_to_rows(df)
def _parse_xlsx(content: bytes) -> list[list[Any]]:
import pandas as pd
# Explicitly specify openpyxl engine; the default engine varies by pandas
# version and does not support legacy .xls (which is excluded by our format map).
df = pd.read_excel(io.BytesIO(content), engine="openpyxl")
return _df_to_rows(df)
def _df_to_rows(df: Any) -> list[list[Any]]:
"""Convert a DataFrame to ``list[list[Any]]`` with a header row.
NaN values are replaced with ``None`` so the result is JSON-serializable.
Uses explicit cell-level checking because ``df.where(df.notna(), None)``
silently converts ``None`` back to ``NaN`` in float64 columns.
"""
header = df.columns.tolist()
rows = [
[None if _is_nan(cell) else cell for cell in row] for row in df.values.tolist()
]
return [header] + rows
def _is_nan(cell: Any) -> bool:
"""Check if a cell value is NaN, handling non-scalar types (lists, dicts).
``pd.isna()`` on a list/dict returns a boolean array which raises
``ValueError`` in a boolean context. Guard with a scalar check first.
"""
import pandas as pd
return bool(pd.api.types.is_scalar(cell) and pd.isna(cell))
_BINARY_PARSERS: dict[str, Callable[[bytes], Any]] = {
"parquet": _parse_parquet,
"xlsx": _parse_xlsx,
}

View File

@@ -0,0 +1,624 @@
"""Tests for file_content_parser — format inference and structured parsing."""
import io
import json
import pytest
from backend.util.file_content_parser import (
BINARY_FORMATS,
infer_format_from_uri,
parse_file_content,
)
# ---------------------------------------------------------------------------
# infer_format_from_uri
# ---------------------------------------------------------------------------
class TestInferFormat:
# --- extension-based ---
def test_json_extension(self):
assert infer_format_from_uri("/home/user/data.json") == "json"
def test_jsonl_extension(self):
assert infer_format_from_uri("/tmp/events.jsonl") == "jsonl"
def test_ndjson_extension(self):
assert infer_format_from_uri("/tmp/events.ndjson") == "jsonl"
def test_csv_extension(self):
assert infer_format_from_uri("workspace:///reports/sales.csv") == "csv"
def test_tsv_extension(self):
assert infer_format_from_uri("/home/user/data.tsv") == "tsv"
def test_yaml_extension(self):
assert infer_format_from_uri("/home/user/config.yaml") == "yaml"
def test_yml_extension(self):
assert infer_format_from_uri("/home/user/config.yml") == "yaml"
def test_toml_extension(self):
assert infer_format_from_uri("/home/user/config.toml") == "toml"
def test_parquet_extension(self):
assert infer_format_from_uri("/data/table.parquet") == "parquet"
def test_xlsx_extension(self):
assert infer_format_from_uri("/data/spreadsheet.xlsx") == "xlsx"
def test_xls_extension_returns_xls_label(self):
# Legacy .xls is mapped so callers can produce a helpful error.
assert infer_format_from_uri("/data/old_spreadsheet.xls") == "xls"
def test_case_insensitive(self):
assert infer_format_from_uri("/data/FILE.JSON") == "json"
assert infer_format_from_uri("/data/FILE.CSV") == "csv"
def test_unicode_filename(self):
assert infer_format_from_uri("/home/user/\u30c7\u30fc\u30bf.json") == "json"
assert infer_format_from_uri("/home/user/\u00e9t\u00e9.csv") == "csv"
def test_unknown_extension(self):
assert infer_format_from_uri("/home/user/readme.txt") is None
def test_no_extension(self):
assert infer_format_from_uri("workspace://abc123") is None
# --- MIME-based ---
def test_mime_json(self):
assert infer_format_from_uri("workspace://abc123#application/json") == "json"
def test_mime_csv(self):
assert infer_format_from_uri("workspace://abc123#text/csv") == "csv"
def test_mime_tsv(self):
assert (
infer_format_from_uri("workspace://abc123#text/tab-separated-values")
== "tsv"
)
def test_mime_ndjson(self):
assert (
infer_format_from_uri("workspace://abc123#application/x-ndjson") == "jsonl"
)
def test_mime_yaml(self):
assert infer_format_from_uri("workspace://abc123#application/x-yaml") == "yaml"
def test_mime_xlsx(self):
uri = "workspace://abc123#application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
assert infer_format_from_uri(uri) == "xlsx"
def test_mime_parquet(self):
assert (
infer_format_from_uri("workspace://abc123#application/vnd.apache.parquet")
== "parquet"
)
def test_unknown_mime(self):
assert infer_format_from_uri("workspace://abc123#text/plain") is None
def test_unknown_mime_falls_through_to_extension(self):
# Unknown MIME (text/plain) should fall through to extension-based detection.
assert infer_format_from_uri("workspace:///data.csv#text/plain") == "csv"
# --- MIME takes precedence over extension ---
def test_mime_overrides_extension(self):
# .txt extension but JSON MIME → json
assert infer_format_from_uri("workspace:///file.txt#application/json") == "json"
# ---------------------------------------------------------------------------
# parse_file_content — JSON
# ---------------------------------------------------------------------------
class TestParseJson:
def test_array(self):
result = parse_file_content("[1, 2, 3]", "json")
assert result == [1, 2, 3]
def test_object(self):
result = parse_file_content('{"key": "value"}', "json")
assert result == {"key": "value"}
def test_nested(self):
content = json.dumps({"rows": [[1, 2], [3, 4]]})
result = parse_file_content(content, "json")
assert result == {"rows": [[1, 2], [3, 4]]}
def test_scalar_string_stays_as_string(self):
result = parse_file_content('"hello"', "json")
assert result == '"hello"' # original content, not parsed
def test_scalar_number_stays_as_string(self):
result = parse_file_content("42", "json")
assert result == "42"
def test_scalar_boolean_stays_as_string(self):
result = parse_file_content("true", "json")
assert result == "true"
def test_null_stays_as_string(self):
result = parse_file_content("null", "json")
assert result == "null"
def test_invalid_json_fallback(self):
content = "not json at all"
result = parse_file_content(content, "json")
assert result == content
def test_empty_string_fallback(self):
result = parse_file_content("", "json")
assert result == ""
def test_bytes_input_decoded(self):
result = parse_file_content(b"[1, 2, 3]", "json")
assert result == [1, 2, 3]
# ---------------------------------------------------------------------------
# parse_file_content — JSONL
# ---------------------------------------------------------------------------
class TestParseJsonl:
def test_tabular_uniform_dicts_to_table_format(self):
"""JSONL with uniform dict keys → table format (header + rows),
consistent with CSV/TSV/Parquet/Excel output."""
content = '{"name":"apple","color":"red"}\n{"name":"banana","color":"yellow"}\n{"name":"cherry","color":"red"}'
result = parse_file_content(content, "jsonl")
assert result == [
["name", "color"],
["apple", "red"],
["banana", "yellow"],
["cherry", "red"],
]
def test_tabular_single_key_dicts(self):
"""JSONL with single-key uniform dicts → table format."""
content = '{"a": 1}\n{"a": 2}\n{"a": 3}'
result = parse_file_content(content, "jsonl")
assert result == [["a"], [1], [2], [3]]
def test_tabular_blank_lines_skipped(self):
content = '{"a": 1}\n\n{"a": 2}\n'
result = parse_file_content(content, "jsonl")
assert result == [["a"], [1], [2]]
def test_heterogeneous_dicts_stay_as_list(self):
"""JSONL with different keys across objects → list of dicts (no table)."""
content = '{"name":"apple"}\n{"color":"red"}\n{"size":3}'
result = parse_file_content(content, "jsonl")
assert result == [{"name": "apple"}, {"color": "red"}, {"size": 3}]
def test_partially_overlapping_keys_stay_as_list(self):
"""JSONL dicts with partially overlapping keys → list of dicts."""
content = '{"name":"apple","color":"red"}\n{"name":"banana","size":"medium"}'
result = parse_file_content(content, "jsonl")
assert result == [
{"name": "apple", "color": "red"},
{"name": "banana", "size": "medium"},
]
def test_mixed_types_stay_as_list(self):
"""JSONL with non-dict lines → list of parsed values (no table)."""
content = '1\n"hello"\n[1,2]\n'
result = parse_file_content(content, "jsonl")
assert result == [1, "hello", [1, 2]]
def test_mixed_dicts_and_non_dicts_stay_as_list(self):
"""JSONL mixing dicts and non-dicts → list of parsed values."""
content = '{"a": 1}\n42\n{"b": 2}'
result = parse_file_content(content, "jsonl")
assert result == [{"a": 1}, 42, {"b": 2}]
def test_tabular_preserves_key_order(self):
"""Table header should follow the key order of the first object."""
content = '{"z": 1, "a": 2}\n{"z": 3, "a": 4}'
result = parse_file_content(content, "jsonl")
assert result[0] == ["z", "a"] # order from first object
assert result[1] == [1, 2]
assert result[2] == [3, 4]
def test_single_dict_stays_as_list(self):
"""Single-line JSONL with one dict → [dict], NOT a table.
Tabular detection requires ≥2 dicts to avoid vacuously true all()."""
content = '{"a": 1, "b": 2}'
result = parse_file_content(content, "jsonl")
assert result == [{"a": 1, "b": 2}]
def test_tabular_with_none_values(self):
"""Uniform keys but some null values → table with None cells."""
content = '{"name":"apple","color":"red"}\n{"name":"banana","color":null}'
result = parse_file_content(content, "jsonl")
assert result == [
["name", "color"],
["apple", "red"],
["banana", None],
]
def test_empty_file_fallback(self):
result = parse_file_content("", "jsonl")
assert result == ""
def test_all_blank_lines_fallback(self):
result = parse_file_content("\n\n\n", "jsonl")
assert result == "\n\n\n"
def test_invalid_line_fallback(self):
content = '{"a": 1}\nnot json\n'
result = parse_file_content(content, "jsonl")
assert result == content # fallback
# ---------------------------------------------------------------------------
# parse_file_content — CSV
# ---------------------------------------------------------------------------
class TestParseCsv:
def test_basic(self):
content = "Name,Score\nAlice,90\nBob,85"
result = parse_file_content(content, "csv")
assert result == [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
def test_quoted_fields(self):
content = 'Name,Bio\nAlice,"Loves, commas"\nBob,Simple'
result = parse_file_content(content, "csv")
assert result[1] == ["Alice", "Loves, commas"]
def test_single_column_fallback(self):
# Only 1 column — not tabular enough.
content = "Name\nAlice\nBob"
result = parse_file_content(content, "csv")
assert result == content
def test_empty_rows_skipped(self):
content = "A,B\n\n1,2\n\n3,4"
result = parse_file_content(content, "csv")
assert result == [["A", "B"], ["1", "2"], ["3", "4"]]
def test_empty_file_fallback(self):
result = parse_file_content("", "csv")
assert result == ""
def test_utf8_bom(self):
"""CSV with a UTF-8 BOM should parse correctly (BOM stripped by decode)."""
bom = "\ufeff"
content = bom + "Name,Score\nAlice,90\nBob,85"
result = parse_file_content(content, "csv")
# The BOM may be part of the first header cell; ensure rows are still parsed.
assert len(result) == 3
assert result[1] == ["Alice", "90"]
assert result[2] == ["Bob", "85"]
# ---------------------------------------------------------------------------
# parse_file_content — TSV
# ---------------------------------------------------------------------------
class TestParseTsv:
def test_basic(self):
content = "Name\tScore\nAlice\t90\nBob\t85"
result = parse_file_content(content, "tsv")
assert result == [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
def test_single_column_fallback(self):
content = "Name\nAlice\nBob"
result = parse_file_content(content, "tsv")
assert result == content
# ---------------------------------------------------------------------------
# parse_file_content — YAML
# ---------------------------------------------------------------------------
class TestParseYaml:
def test_list(self):
content = "- apple\n- banana\n- cherry"
result = parse_file_content(content, "yaml")
assert result == ["apple", "banana", "cherry"]
def test_dict(self):
content = "name: Alice\nage: 30"
result = parse_file_content(content, "yaml")
assert result == {"name": "Alice", "age": 30}
def test_nested(self):
content = "users:\n - name: Alice\n - name: Bob"
result = parse_file_content(content, "yaml")
assert result == {"users": [{"name": "Alice"}, {"name": "Bob"}]}
def test_scalar_stays_as_string(self):
result = parse_file_content("hello world", "yaml")
assert result == "hello world"
def test_invalid_yaml_fallback(self):
content = ":\n :\n invalid: - -"
result = parse_file_content(content, "yaml")
# Malformed YAML should fall back to the original string, not raise.
assert result == content
# ---------------------------------------------------------------------------
# parse_file_content — TOML
# ---------------------------------------------------------------------------
class TestParseToml:
def test_basic(self):
content = '[server]\nhost = "localhost"\nport = 8080'
result = parse_file_content(content, "toml")
assert result == {"server": {"host": "localhost", "port": 8080}}
def test_flat(self):
content = 'name = "test"\ncount = 42'
result = parse_file_content(content, "toml")
assert result == {"name": "test", "count": 42}
def test_empty_string_returns_empty_dict(self):
result = parse_file_content("", "toml")
assert result == {}
def test_invalid_toml_fallback(self):
result = parse_file_content("not = [valid toml", "toml")
assert result == "not = [valid toml"
# ---------------------------------------------------------------------------
# parse_file_content — Parquet (binary)
# ---------------------------------------------------------------------------
try:
import pyarrow as _pa # noqa: F401 # pyright: ignore[reportMissingImports]
_has_pyarrow = True
except ImportError:
_has_pyarrow = False
@pytest.mark.skipif(not _has_pyarrow, reason="pyarrow not installed")
class TestParseParquet:
@pytest.fixture
def parquet_bytes(self) -> bytes:
import pandas as pd
df = pd.DataFrame({"Name": ["Alice", "Bob"], "Score": [90, 85]})
buf = io.BytesIO()
df.to_parquet(buf, index=False)
return buf.getvalue()
def test_basic(self, parquet_bytes: bytes):
result = parse_file_content(parquet_bytes, "parquet")
assert result == [["Name", "Score"], ["Alice", 90], ["Bob", 85]]
def test_string_input_fallback(self):
# Parquet is binary — string input can't be parsed.
result = parse_file_content("not parquet", "parquet")
assert result == "not parquet"
def test_invalid_bytes_fallback(self):
result = parse_file_content(b"not parquet bytes", "parquet")
assert result == b"not parquet bytes"
def test_empty_bytes_fallback(self):
"""Empty binary input should return the empty bytes, not crash."""
result = parse_file_content(b"", "parquet")
assert result == b""
def test_nan_replaced_with_none(self):
"""NaN values in Parquet must become None for JSON serializability."""
import math
import pandas as pd
df = pd.DataFrame({"A": [1.0, float("nan"), 3.0], "B": ["x", None, "z"]})
buf = io.BytesIO()
df.to_parquet(buf, index=False)
result = parse_file_content(buf.getvalue(), "parquet")
# Row with NaN in float col → None
assert result[2][0] is None # float NaN → None
assert result[2][1] is None # str None → None
# Ensure no NaN leaks
for row in result[1:]:
for cell in row:
if isinstance(cell, float):
assert not math.isnan(cell), f"NaN leaked: {row}"
# ---------------------------------------------------------------------------
# parse_file_content — Excel (binary)
# ---------------------------------------------------------------------------
class TestParseExcel:
@pytest.fixture
def xlsx_bytes(self) -> bytes:
import pandas as pd
df = pd.DataFrame({"Name": ["Alice", "Bob"], "Score": [90, 85]})
buf = io.BytesIO()
df.to_excel(buf, index=False) # type: ignore[arg-type] # BytesIO is a valid target
return buf.getvalue()
def test_basic(self, xlsx_bytes: bytes):
result = parse_file_content(xlsx_bytes, "xlsx")
assert result == [["Name", "Score"], ["Alice", 90], ["Bob", 85]]
def test_string_input_fallback(self):
result = parse_file_content("not xlsx", "xlsx")
assert result == "not xlsx"
def test_invalid_bytes_fallback(self):
result = parse_file_content(b"not xlsx bytes", "xlsx")
assert result == b"not xlsx bytes"
def test_empty_bytes_fallback(self):
"""Empty binary input should return the empty bytes, not crash."""
result = parse_file_content(b"", "xlsx")
assert result == b""
def test_nan_replaced_with_none(self):
"""NaN values in float columns must become None for JSON serializability."""
import math
import pandas as pd
df = pd.DataFrame({"A": [1.0, float("nan"), 3.0], "B": ["x", "y", None]})
buf = io.BytesIO()
df.to_excel(buf, index=False) # type: ignore[arg-type]
result = parse_file_content(buf.getvalue(), "xlsx")
# Row with NaN in float col → None, not float('nan')
assert result[2][0] is None # float NaN → None
assert result[3][1] is None # str None → None
# Ensure no NaN leaks
for row in result[1:]: # skip header
for cell in row:
if isinstance(cell, float):
assert not math.isnan(cell), f"NaN leaked: {row}"
# ---------------------------------------------------------------------------
# parse_file_content — unknown format / fallback
# ---------------------------------------------------------------------------
class TestFallback:
def test_unknown_format_returns_content(self):
result = parse_file_content("hello world", "xml")
assert result == "hello world"
def test_none_format_returns_content(self):
# Shouldn't normally be called with unrecognised format, but must not crash.
result = parse_file_content("hello", "unknown_format")
assert result == "hello"
# ---------------------------------------------------------------------------
# BINARY_FORMATS
# ---------------------------------------------------------------------------
class TestBinaryFormats:
def test_parquet_is_binary(self):
assert "parquet" in BINARY_FORMATS
def test_xlsx_is_binary(self):
assert "xlsx" in BINARY_FORMATS
def test_text_formats_not_binary(self):
for fmt in ("json", "jsonl", "csv", "tsv", "yaml", "toml"):
assert fmt not in BINARY_FORMATS
# ---------------------------------------------------------------------------
# MIME mapping
# ---------------------------------------------------------------------------
class TestMimeMapping:
def test_application_yaml(self):
assert infer_format_from_uri("workspace://abc123#application/yaml") == "yaml"
# ---------------------------------------------------------------------------
# CSV sniffer fallback
# ---------------------------------------------------------------------------
class TestCsvSnifferFallback:
def test_tab_delimited_with_csv_format(self):
"""Tab-delimited content parsed as csv should use sniffer fallback."""
content = "Name\tScore\nAlice\t90\nBob\t85"
result = parse_file_content(content, "csv")
assert result == [["Name", "Score"], ["Alice", "90"], ["Bob", "85"]]
def test_sniffer_failure_returns_content(self):
"""When sniffer fails, single-column falls back to raw content."""
content = "Name\nAlice\nBob"
result = parse_file_content(content, "csv")
assert result == content
# ---------------------------------------------------------------------------
# OpenpyxlInvalidFile fallback
# ---------------------------------------------------------------------------
class TestOpenpyxlFallback:
def test_invalid_xlsx_non_strict(self):
"""Invalid xlsx bytes should fall back gracefully in non-strict mode."""
result = parse_file_content(b"not xlsx bytes", "xlsx")
assert result == b"not xlsx bytes"
# ---------------------------------------------------------------------------
# Header-only CSV
# ---------------------------------------------------------------------------
class TestHeaderOnlyCsv:
def test_header_only_csv_returns_header_row(self):
"""CSV with only a header row (no data rows) should return [[header]]."""
content = "Name,Score"
result = parse_file_content(content, "csv")
assert result == [["Name", "Score"]]
def test_header_only_csv_with_trailing_newline(self):
content = "Name,Score\n"
result = parse_file_content(content, "csv")
assert result == [["Name", "Score"]]
# ---------------------------------------------------------------------------
# Binary format + line range (line range ignored for binary formats)
# ---------------------------------------------------------------------------
@pytest.mark.skipif(not _has_pyarrow, reason="pyarrow not installed")
class TestBinaryFormatLineRange:
def test_parquet_ignores_line_range(self):
"""Binary formats should parse the full file regardless of line range.
Line ranges are meaningless for binary formats (parquet/xlsx) — the
caller (file_ref._expand_bare_ref) passes raw bytes and the parser
should return the complete structured data.
"""
import pandas as pd
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
buf = io.BytesIO()
df.to_parquet(buf, index=False)
# parse_file_content itself doesn't take a line range — this tests
# that the full content is parsed even though the bytes could have
# been truncated upstream (it's not, by design).
result = parse_file_content(buf.getvalue(), "parquet")
assert result == [["A", "B"], [1, 4], [2, 5], [3, 6]]
# ---------------------------------------------------------------------------
# Legacy .xls UX
# ---------------------------------------------------------------------------
class TestXlsFallback:
def test_xls_returns_helpful_error_string(self):
"""Uploading a .xls file should produce a helpful error, not garbled binary."""
result = parse_file_content(b"\xd0\xcf\x11\xe0garbled", "xls")
assert isinstance(result, str)
assert ".xlsx" in result
assert "not supported" in result.lower()
def test_xls_with_string_content(self):
result = parse_file_content("some text", "xls")
assert isinstance(result, str)
assert ".xlsx" in result

View File

@@ -8,7 +8,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.data.execution import ExecutionContext
from backend.util.file import store_media_file
from backend.util.file import (
is_media_file_ref,
parse_data_uri,
resolve_media_content,
store_media_file,
)
from backend.util.type import MediaFileType
@@ -344,3 +349,162 @@ class TestFileCloudIntegration:
execution_context=make_test_context(graph_exec_id=graph_exec_id),
return_format="for_local_processing",
)
# ---------------------------------------------------------------------------
# is_media_file_ref
# ---------------------------------------------------------------------------
class TestIsMediaFileRef:
def test_data_uri(self):
assert is_media_file_ref("data:image/png;base64,iVBORw0KGg==") is True
def test_workspace_uri(self):
assert is_media_file_ref("workspace://abc123") is True
def test_workspace_uri_with_mime(self):
assert is_media_file_ref("workspace://abc123#image/png") is True
def test_http_url(self):
assert is_media_file_ref("http://example.com/image.png") is True
def test_https_url(self):
assert is_media_file_ref("https://example.com/image.png") is True
def test_plain_text(self):
assert is_media_file_ref("print('hello')") is False
def test_local_path(self):
assert is_media_file_ref("/tmp/file.txt") is False
def test_empty_string(self):
assert is_media_file_ref("") is False
def test_filename(self):
assert is_media_file_ref("image.png") is False
# ---------------------------------------------------------------------------
# parse_data_uri
# ---------------------------------------------------------------------------
class TestParseDataUri:
def test_valid_png(self):
result = parse_data_uri("data:image/png;base64,iVBORw0KGg==")
assert result is not None
mime, payload = result
assert mime == "image/png"
assert payload == "iVBORw0KGg=="
def test_valid_text(self):
result = parse_data_uri("data:text/plain;base64,SGVsbG8=")
assert result is not None
assert result[0] == "text/plain"
assert result[1] == "SGVsbG8="
def test_mime_case_normalized(self):
result = parse_data_uri("data:IMAGE/PNG;base64,abc")
assert result is not None
assert result[0] == "image/png"
def test_not_data_uri(self):
assert parse_data_uri("workspace://abc123") is None
def test_plain_text(self):
assert parse_data_uri("hello world") is None
def test_missing_base64(self):
assert parse_data_uri("data:image/png;utf-8,abc") is None
def test_empty_payload(self):
result = parse_data_uri("data:image/png;base64,")
assert result is not None
assert result[1] == ""
# ---------------------------------------------------------------------------
# resolve_media_content
# ---------------------------------------------------------------------------
class TestResolveMediaContent:
@pytest.mark.asyncio
async def test_plain_text_passthrough(self):
"""Plain text content (not a media ref) passes through unchanged."""
ctx = make_test_context()
result = await resolve_media_content(
MediaFileType("print('hello')"),
ctx,
return_format="for_external_api",
)
assert result == "print('hello')"
@pytest.mark.asyncio
async def test_empty_string_passthrough(self):
"""Empty string passes through unchanged."""
ctx = make_test_context()
result = await resolve_media_content(
MediaFileType(""),
ctx,
return_format="for_external_api",
)
assert result == ""
@pytest.mark.asyncio
async def test_media_ref_delegates_to_store(self):
"""Media references are resolved via store_media_file."""
ctx = make_test_context()
with patch(
"backend.util.file.store_media_file",
new=AsyncMock(return_value=MediaFileType("data:image/png;base64,abc")),
) as mock_store:
result = await resolve_media_content(
MediaFileType("workspace://img123"),
ctx,
return_format="for_external_api",
)
assert result == "data:image/png;base64,abc"
mock_store.assert_called_once_with(
MediaFileType("workspace://img123"),
ctx,
return_format="for_external_api",
)
@pytest.mark.asyncio
async def test_data_uri_delegates_to_store(self):
"""Data URIs are also resolved via store_media_file."""
ctx = make_test_context()
data_uri = "data:image/png;base64,iVBORw0KGg=="
with patch(
"backend.util.file.store_media_file",
new=AsyncMock(return_value=MediaFileType(data_uri)),
) as mock_store:
result = await resolve_media_content(
MediaFileType(data_uri),
ctx,
return_format="for_external_api",
)
assert result == data_uri
mock_store.assert_called_once()
@pytest.mark.asyncio
async def test_https_url_delegates_to_store(self):
"""HTTPS URLs are resolved via store_media_file."""
ctx = make_test_context()
with patch(
"backend.util.file.store_media_file",
new=AsyncMock(return_value=MediaFileType("data:image/png;base64,abc")),
) as mock_store:
result = await resolve_media_content(
MediaFileType("https://example.com/image.png"),
ctx,
return_format="for_local_processing",
)
assert result == "data:image/png;base64,abc"
mock_store.assert_called_once_with(
MediaFileType("https://example.com/image.png"),
ctx,
return_format="for_local_processing",
)

View File

@@ -1360,6 +1360,18 @@ files = [
dnspython = ">=2.0.0"
idna = ">=2.0.0"
[[package]]
name = "et-xmlfile"
version = "2.0.0"
description = "An implementation of lxml.xmlfile for the standard library"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "et_xmlfile-2.0.0-py3-none-any.whl", hash = "sha256:7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa"},
{file = "et_xmlfile-2.0.0.tar.gz", hash = "sha256:dab3f4764309081ce75662649be815c4c9081e88f0837825f90fd28317d4da54"},
]
[[package]]
name = "exa-py"
version = "1.16.1"
@@ -4228,6 +4240,21 @@ datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
realtime = ["websockets (>=13,<16)"]
voice-helpers = ["numpy (>=2.0.2)", "sounddevice (>=0.5.1)"]
[[package]]
name = "openpyxl"
version = "3.1.5"
description = "A Python library to read/write Excel 2010 xlsx/xlsm files"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "openpyxl-3.1.5-py2.py3-none-any.whl", hash = "sha256:5282c12b107bffeef825f4617dc029afaf41d0ea60823bbb665ef3079dc79de2"},
{file = "openpyxl-3.1.5.tar.gz", hash = "sha256:cf0e3cf56142039133628b5acffe8ef0c12bc902d2aadd3e0fe5878dc08d1050"},
]
[package.dependencies]
et-xmlfile = "*"
[[package]]
name = "opentelemetry-api"
version = "1.39.1"
@@ -5430,6 +5457,66 @@ files = [
{file = "psycopg2_binary-2.9.11-cp39-cp39-win_amd64.whl", hash = "sha256:875039274f8a2361e5207857899706da840768e2a775bf8c65e82f60b197df02"},
]
[[package]]
name = "pyarrow"
version = "23.0.1"
description = "Python library for Apache Arrow"
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
{file = "pyarrow-23.0.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:3fab8f82571844eb3c460f90a75583801d14ca0cc32b1acc8c361650e006fd56"},
{file = "pyarrow-23.0.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:3f91c038b95f71ddfc865f11d5876c42f343b4495535bd262c7b321b0b94507c"},
{file = "pyarrow-23.0.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:d0744403adabef53c985a7f8a082b502a368510c40d184df349a0a8754533258"},
{file = "pyarrow-23.0.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:c33b5bf406284fd0bba436ed6f6c3ebe8e311722b441d89397c54f871c6863a2"},
{file = "pyarrow-23.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ddf743e82f69dcd6dbbcb63628895d7161e04e56794ef80550ac6f3315eeb1d5"},
{file = "pyarrow-23.0.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e052a211c5ac9848ae15d5ec875ed0943c0221e2fcfe69eee80b604b4e703222"},
{file = "pyarrow-23.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:5abde149bb3ce524782d838eb67ac095cd3fd6090eba051130589793f1a7f76d"},
{file = "pyarrow-23.0.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:6f0147ee9e0386f519c952cc670eb4a8b05caa594eeffe01af0e25f699e4e9bb"},
{file = "pyarrow-23.0.1-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:0ae6e17c828455b6265d590100c295193f93cc5675eb0af59e49dbd00d2de350"},
{file = "pyarrow-23.0.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:fed7020203e9ef273360b9e45be52a2a47d3103caf156a30ace5247ffb51bdbd"},
{file = "pyarrow-23.0.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:26d50dee49d741ac0e82185033488d28d35be4d763ae6f321f97d1140eb7a0e9"},
{file = "pyarrow-23.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:3c30143b17161310f151f4a2bcfe41b5ff744238c1039338779424e38579d701"},
{file = "pyarrow-23.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:db2190fa79c80a23fdd29fef4b8992893f024ae7c17d2f5f4db7171fa30c2c78"},
{file = "pyarrow-23.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:f00f993a8179e0e1c9713bcc0baf6d6c01326a406a9c23495ec1ba9c9ebf2919"},
{file = "pyarrow-23.0.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:f4b0dbfa124c0bb161f8b5ebb40f1a680b70279aa0c9901d44a2b5a20806039f"},
{file = "pyarrow-23.0.1-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:7707d2b6673f7de054e2e83d59f9e805939038eebe1763fe811ee8fa5c0cd1a7"},
{file = "pyarrow-23.0.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:86ff03fb9f1a320266e0de855dee4b17da6794c595d207f89bba40d16b5c78b9"},
{file = "pyarrow-23.0.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:813d99f31275919c383aab17f0f455a04f5a429c261cc411b1e9a8f5e4aaaa05"},
{file = "pyarrow-23.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bf5842f960cddd2ef757d486041d57c96483efc295a8c4a0e20e704cbbf39c67"},
{file = "pyarrow-23.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:564baf97c858ecc03ec01a41062e8f4698abc3e6e2acd79c01c2e97880a19730"},
{file = "pyarrow-23.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:07deae7783782ac7250989a7b2ecde9b3c343a643f82e8a4df03d93b633006f0"},
{file = "pyarrow-23.0.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:6b8fda694640b00e8af3c824f99f789e836720aa8c9379fb435d4c4953a756b8"},
{file = "pyarrow-23.0.1-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:8ff51b1addc469b9444b7c6f3548e19dc931b172ab234e995a60aea9f6e6025f"},
{file = "pyarrow-23.0.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:71c5be5cbf1e1cb6169d2a0980850bccb558ddc9b747b6206435313c47c37677"},
{file = "pyarrow-23.0.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:9b6f4f17b43bc39d56fec96e53fe89d94bac3eb134137964371b45352d40d0c2"},
{file = "pyarrow-23.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9fc13fc6c403d1337acab46a2c4346ca6c9dec5780c3c697cf8abfd5e19b6b37"},
{file = "pyarrow-23.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5c16ed4f53247fa3ffb12a14d236de4213a4415d127fe9cebed33d51671113e2"},
{file = "pyarrow-23.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:cecfb12ef629cf6be0b1887f9f86463b0dd3dc3195ae6224e74006be4736035a"},
{file = "pyarrow-23.0.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:29f7f7419a0e30264ea261fdc0e5fe63ce5a6095003db2945d7cd78df391a7e1"},
{file = "pyarrow-23.0.1-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:33d648dc25b51fd8055c19e4261e813dfc4d2427f068bcecc8b53d01b81b0500"},
{file = "pyarrow-23.0.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:cd395abf8f91c673dd3589cadc8cc1ee4e8674fa61b2e923c8dd215d9c7d1f41"},
{file = "pyarrow-23.0.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:00be9576d970c31defb5c32eb72ef585bf600ef6d0a82d5eccaae96639cf9d07"},
{file = "pyarrow-23.0.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:c2139549494445609f35a5cda4eb94e2c9e4d704ce60a095b342f82460c73a83"},
{file = "pyarrow-23.0.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:7044b442f184d84e2351e5084600f0d7343d6117aabcbc1ac78eb1ae11eb4125"},
{file = "pyarrow-23.0.1-cp313-cp313t-win_amd64.whl", hash = "sha256:a35581e856a2fafa12f3f54fce4331862b1cfb0bef5758347a858a4aa9d6bae8"},
{file = "pyarrow-23.0.1-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:5df1161da23636a70838099d4aaa65142777185cc0cdba4037a18cee7d8db9ca"},
{file = "pyarrow-23.0.1-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:fa8e51cb04b9f8c9c5ace6bab63af9a1f88d35c0d6cbf53e8c17c098552285e1"},
{file = "pyarrow-23.0.1-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:0b95a3994f015be13c63148fef8832e8a23938128c185ee951c98908a696e0eb"},
{file = "pyarrow-23.0.1-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:4982d71350b1a6e5cfe1af742c53dfb759b11ce14141870d05d9e540d13bc5d1"},
{file = "pyarrow-23.0.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c250248f1fe266db627921c89b47b7c06fee0489ad95b04d50353537d74d6886"},
{file = "pyarrow-23.0.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5f4763b83c11c16e5f4c15601ba6dfa849e20723b46aa2617cb4bffe8768479f"},
{file = "pyarrow-23.0.1-cp314-cp314-win_amd64.whl", hash = "sha256:3a4c85ef66c134161987c17b147d6bffdca4566f9a4c1d81a0a01cdf08414ea5"},
{file = "pyarrow-23.0.1-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:17cd28e906c18af486a499422740298c52d7c6795344ea5002a7720b4eadf16d"},
{file = "pyarrow-23.0.1-cp314-cp314t-macosx_12_0_x86_64.whl", hash = "sha256:76e823d0e86b4fb5e1cf4a58d293036e678b5a4b03539be933d3b31f9406859f"},
{file = "pyarrow-23.0.1-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:a62e1899e3078bf65943078b3ad2a6ddcacf2373bc06379aac61b1e548a75814"},
{file = "pyarrow-23.0.1-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:df088e8f640c9fae3b1f495b3c64755c4e719091caf250f3a74d095ddf3c836d"},
{file = "pyarrow-23.0.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:46718a220d64677c93bc243af1d44b55998255427588e400677d7192671845c7"},
{file = "pyarrow-23.0.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:a09f3876e87f48bc2f13583ab551f0379e5dfb83210391e68ace404181a20690"},
{file = "pyarrow-23.0.1-cp314-cp314t-win_amd64.whl", hash = "sha256:527e8d899f14bd15b740cd5a54ad56b7f98044955373a17179d5956ddb93d9ce"},
{file = "pyarrow-23.0.1.tar.gz", hash = "sha256:b8c5873e33440b2bc2f4a79d2b47017a89c5a24116c055625e6f2ee50523f019"},
]
[[package]]
name = "pyasn1"
version = "0.6.2"
@@ -8882,4 +8969,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.14"
content-hash = "4e4365721cd3b68c58c237353b74adae1c64233fd4446904c335f23eb866fdca"
content-hash = "86dab25684dd46e635a33bd33281a926e5626a874ecc048c34389fecf34a87d8"

View File

@@ -92,6 +92,8 @@ gravitas-md2gdocs = "^0.1.0"
posthog = "^7.6.0"
fpdf2 = "^2.8.6"
langsmith = "^0.7.7"
openpyxl = "^3.1.5"
pyarrow = "^23.0.0"
[tool.poetry.group.dev.dependencies]
aiohappyeyeballs = "^2.6.1"

View File

@@ -3,6 +3,7 @@ import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { ExclamationMarkIcon } from "@phosphor-icons/react";
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
import { useState } from "react";
import { ConnectIntegrationTool } from "../../../tools/ConnectIntegrationTool/ConnectIntegrationTool";
import { CreateAgentTool } from "../../../tools/CreateAgent/CreateAgent";
import { EditAgentTool } from "../../../tools/EditAgent/EditAgent";
import {
@@ -129,6 +130,8 @@ export function MessagePartRenderer({ part, messageID, partIndex }: Props) {
case "tool-search_docs":
case "tool-get_doc_page":
return <SearchDocsTool key={key} part={part as ToolUIPart} />;
case "tool-connect_integration":
return <ConnectIntegrationTool key={key} part={part as ToolUIPart} />;
case "tool-run_block":
case "tool-continue_run_block":
return <RunBlockTool key={key} part={part as ToolUIPart} />;

View File

@@ -0,0 +1,104 @@
"use client";
import type { SetupRequirementsResponse } from "@/app/api/__generated__/models/setupRequirementsResponse";
import type { ToolUIPart } from "ai";
import { useState } from "react";
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
import { ContentMessage } from "../../components/ToolAccordion/AccordionContent";
import { SetupRequirementsCard } from "../RunBlock/components/SetupRequirementsCard/SetupRequirementsCard";
type Props = {
part: ToolUIPart;
};
function parseJson(raw: unknown): unknown {
if (typeof raw === "string") {
try {
return JSON.parse(raw);
} catch {
return null;
}
}
return raw;
}
function parseOutput(raw: unknown): SetupRequirementsResponse | null {
const parsed = parseJson(raw);
if (parsed && typeof parsed === "object" && "setup_info" in parsed) {
return parsed as SetupRequirementsResponse;
}
return null;
}
function parseError(raw: unknown): string | null {
const parsed = parseJson(raw);
if (parsed && typeof parsed === "object" && "message" in parsed) {
return String((parsed as { message: unknown }).message);
}
return null;
}
export function ConnectIntegrationTool({ part }: Props) {
// Persist dismissed state here so SetupRequirementsCard remounts don't re-enable Proceed.
const [isDismissed, setIsDismissed] = useState(false);
const isStreaming =
part.state === "input-streaming" || part.state === "input-available";
const isError = part.state === "output-error";
const output =
part.state === "output-available"
? parseOutput((part as { output?: unknown }).output)
: null;
const errorMessage = isError
? (parseError((part as { output?: unknown }).output) ??
"Failed to connect integration")
: null;
const rawProvider =
(part as { input?: { provider?: string } }).input?.provider ?? "";
const providerName =
output?.setup_info?.agent_name ??
// Sanitize LLM-controlled provider slug: trim and cap at 64 chars to
// prevent runaway text in the DOM.
(rawProvider ? rawProvider.trim().slice(0, 64) : "integration");
const label = isStreaming
? `Connecting ${providerName}`
: isError
? `Failed to connect ${providerName}`
: output
? `Connect ${output.setup_info?.agent_name ?? providerName}`
: `Connect ${providerName}`;
return (
<div className="py-2">
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<MorphingTextAnimation
text={label}
className={isError ? "text-red-500" : undefined}
/>
</div>
{isError && errorMessage && (
<p className="mt-1 text-sm text-red-500">{errorMessage}</p>
)}
{output && (
<div className="mt-2">
{isDismissed ? (
<ContentMessage>Connected. Continuing</ContentMessage>
) : (
<SetupRequirementsCard
output={output}
credentialsLabel={`${output.setup_info?.agent_name ?? providerName} credentials`}
retryInstruction="I've connected my account. Please continue."
onComplete={() => setIsDismissed(true)}
/>
)}
</div>
)}
</div>
);
}

View File

@@ -23,12 +23,16 @@ interface Props {
/** Override the label shown above the credentials section.
* Defaults to "Credentials". */
credentialsLabel?: string;
/** Called after Proceed is clicked so the parent can persist the dismissed state
* across remounts (avoids re-enabling the Proceed button on remount). */
onComplete?: () => void;
}
export function SetupRequirementsCard({
output,
retryInstruction,
credentialsLabel,
onComplete,
}: Props) {
const { onSend } = useCopilotChatActions();
@@ -68,13 +72,17 @@ export function SetupRequirementsCard({
return v !== undefined && v !== null && v !== "";
});
if (hasSent) {
return <ContentMessage>Connected. Continuing</ContentMessage>;
}
const canRun =
!hasSent &&
(!needsCredentials || isAllCredentialsComplete) &&
(!needsInputs || isAllInputsComplete);
function handleRun() {
setHasSent(true);
onComplete?.();
const parts: string[] = [];
if (needsCredentials) {

View File

@@ -0,0 +1,40 @@
"use client";
import { ArrowRight, Lightning } from "@phosphor-icons/react";
import NextLink from "next/link";
import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import { useJumpBackIn } from "./useJumpBackIn";
export function JumpBackIn() {
const { agent, isLoading } = useJumpBackIn();
if (isLoading || !agent) {
return null;
}
return (
<div className="flex items-center justify-between rounded-large border border-zinc-200 bg-gradient-to-r from-zinc-50 to-white px-5 py-4">
<div className="flex items-center gap-3">
<div className="flex h-9 w-9 items-center justify-center rounded-full bg-zinc-900">
<Lightning size={18} weight="fill" className="text-white" />
</div>
<div className="flex flex-col">
<Text variant="small" className="text-zinc-500">
Continue where you left off
</Text>
<Text variant="body-medium" className="text-zinc-900">
{agent.name}
</Text>
</div>
</div>
<NextLink href={`/library/agents/${agent.id}`}>
<Button variant="primary" size="small" className="gap-1.5">
Jump Back In
<ArrowRight size={16} />
</Button>
</NextLink>
</div>
);
}

View File

@@ -0,0 +1,28 @@
"use client";
import { useGetV2ListLibraryAgents } from "@/app/api/__generated__/endpoints/library/library";
import { okData } from "@/app/api/helpers";
export function useJumpBackIn() {
const { data, isLoading } = useGetV2ListLibraryAgents(
{
page: 1,
page_size: 1,
sort_by: "updatedAt",
},
{
query: { select: okData },
},
);
// The API doesn't include execution data by default (include_executions is
// internal to the backend), so recent_executions is always empty here.
// We use the most recently updated agent as the "jump back in" candidate
// instead — updatedAt is the best available proxy for recent activity.
const agent = data?.agents[0] ?? null;
return {
agent,
isLoading,
};
}

View File

@@ -2,6 +2,7 @@
import { useEffect, useState, useCallback } from "react";
import { HeartIcon, ListIcon } from "@phosphor-icons/react";
import { JumpBackIn } from "./components/JumpBackIn/JumpBackIn";
import { LibraryActionHeader } from "./components/LibraryActionHeader/LibraryActionHeader";
import { LibraryAgentList } from "./components/LibraryAgentList/LibraryAgentList";
import { Tab } from "./components/LibraryTabs/LibraryTabs";
@@ -38,6 +39,7 @@ export default function LibraryPage() {
onAnimationComplete={handleFavoriteAnimationComplete}
>
<main className="pt-160 container min-h-screen space-y-4 pb-20 pt-16 sm:px-8 md:px-12">
<JumpBackIn />
<LibraryActionHeader setSearchTerm={setSearchTerm} />
<LibraryAgentList
searchTerm={searchTerm}

View File

@@ -125,9 +125,9 @@ export function useCredentialsInput({
if (hasAttemptedAutoSelect.current) return;
hasAttemptedAutoSelect.current = true;
// Auto-select if exactly one credential matches.
// For optional fields with multiple options, let the user choose.
if (isOptional && savedCreds.length > 1) return;
// Auto-select only when there is exactly one saved credential.
// With multiple options the user must choose — regardless of optional/required.
if (savedCreds.length > 1) return;
const cred = savedCreds[0];
onSelectCredential({