Compare commits

..

23 Commits

Author SHA1 Message Date
Zamil Majdy
88eaab2baa Merge remote-tracking branch 'origin/dev' into feat/github-cli-copilot 2026-03-17 06:17:03 +07:00
Zamil Majdy
4b0a445635 fix(copilot): remove implicit gh auth setup-git from sandbox creation
Remove the automatic GitHub credential helper configuration that ran on
every E2B sandbox connect/reconnect. This addressed a review concern
about implicitly giving AutoPilot full GitHub access without user
awareness or opt-in.

The bash_exec tool already injects GH_TOKEN/GITHUB_TOKEN per-command
for users who have connected their account via connect_integration,
which is the explicit opt-in path.
2026-03-17 00:36:51 +07:00
Zamil Majdy
36312d2c6e fix(backend/copilot): bust cache on OAuth refresh + persist dismissed state
- creds_manager: call _bust_copilot_cache after refresh_if_needed
  persists the refreshed token so the copilot cache doesn't serve
  a stale access token after silent refresh
- ConnectIntegrationTool: lift isDismissed state to parent so
  SetupRequirementsCard remounts don't re-enable the Proceed button;
  onComplete callback propagates the dismissed signal up
2026-03-16 17:10:18 +07:00
Zamil Majdy
d6d3b8d710 fix(copilot): address coderabbitai major issues — scope token description to E2B, guard cache-bust hook
- connect_integration.py: clarify that GH_TOKEN is injected per-command in
  E2B/cloud only; note that bubblewrap isolates network so retry won't work
- creds_manager._bust_copilot_cache: wrap _on_creds_changed in try/except so
  a failing hook doesn't turn successful create/update/delete into a 500
2026-03-16 15:52:40 +07:00
Zamil Majdy
17d8d0bf05 fix(backend/copilot): run gh auth setup-git once on sandbox connect/reconnect
Move git credential helper setup out of bash_exec (where it ran on
every command) and into _setup_e2b so it runs exactly once per sandbox
connect or reconnect. Non-fatal: logged at debug level on failure.
2026-03-16 15:45:18 +07:00
Zamil Majdy
5a2ab65f41 fix(backend/copilot): run gh auth setup-git once per sandbox session
Use grep to skip re-running if the credential helper is already
configured in ~/.gitconfig — only pays the cost on first command.
Agent can still call it manually if GH_TOKEN changes mid-session.
2026-03-16 15:42:54 +07:00
Zamil Majdy
81a318de3e feat(backend/copilot): improve GitHub OAuth UX and git auth
- Dynamic OAuth scopes: connect_integration tool now accepts a `scopes`
  param so the agent can request exactly the access it needs (e.g.
  `["repo", "read:org"]`); GitHub defaults to `["repo"]` so git
  push/pull works out of the box instead of public-data-only
- Lazy git auth: prepend `gh auth setup-git` on every E2B bash_exec
  when GH_TOKEN is present — git HTTPS clone/push/pull now work
  automatically without the agent needing to set this up manually
- Prefer broadest-scoped OAuth2 credential: sort repo-scoped tokens
  first so a stale public-data token is never picked over a full one
- Collapse SetupRequirementsCard to "Connected. Continuing…" after
  Proceed is clicked instead of leaving the full card visible
- Fix credential auto-select: don't silently pick the first token when
  multiple credentials exist — let the user choose via the dropdown
2026-03-16 15:26:14 +07:00
Zamil Majdy
62c8e8634b fix(copilot): patch _manager singleton directly in tests instead of class constructor
The module-level _manager singleton is created at import time, so patching
IntegrationCredentialsManager after import has no effect. Patch the _manager
attribute directly so get_provider_token uses the mock.
2026-03-16 06:32:36 +07:00
Zamil Majdy
b91c959cd9 fix(copilot): address remaining review findings
- creds_manager: fix TOCTOU in delete() — move get_creds_by_id inside the lock
- creds_manager: replace lazy import in _bust_copilot_cache with a
  register_creds_changed_hook() callback so creds_manager has no runtime
  dependency on the copilot module
- integration_creds: register invalidate_user_provider_cache at module import
  via register_creds_changed_hook() — eliminates the circular-import risk
- integration_creds: add module-level _manager singleton (avoids re-instantiating
  IntegrationCredentialsManager on every cache miss)
- integration_creds: document TTLCache asyncio-only thread-safety assumption
- connect_integration: defer GITHUB_OAUTH_IS_CONFIGURED evaluation to runtime
  with an lru_cache'd helper; importing the module no longer triggers Secrets()
- connect_integration: type missing_credentials dict with _CredentialEntry TypedDict
- connect_integration: cap reason field at 500 chars; add maxLength to JSON schema
- bash_exec: use 'user_id is not None' instead of truthy check
- connect_integration_test: add test for unauthenticated caller (requires_auth guard)
- bash_exec_test: add E2B path tests — token injected when user_id set, skipped
  when user_id is None
- ConnectIntegrationTool.tsx: sanitize LLM-controlled providerName fallback
  (trim + slice to 64 chars)
2026-03-16 06:21:44 +07:00
Zamil Majdy
5b95a2a1ef refactor(copilot): strongly type _PROVIDER_INFO with TypedDict
Replace dict[str, Any] with a _ProviderInfo TypedDict for provider
metadata entries, eliminating key/type drift as new providers are added.
2026-03-16 06:04:02 +07:00
Zamil Majdy
9c2a601167 refactor(copilot): simplify cache with cachetools.TTLCache, fix prompt wording
- Replace manual dict+sentinel cache with two TTLCache instances:
  _token_cache (5min TTL) and _null_cache (60s TTL)
- Remove _cache_set helper and _NO_TOKEN sentinel — TTLCache handles
  expiry and LRU eviction natively
- Update tests to use _token_cache/_null_cache directly; add TTL constant test
- Change _E2B_TOOL_NOTES from "GH_TOKEN is set" to "gh is pre-authenticated"
  so the AI doesn't attempt to read the env var directly
2026-03-16 00:16:26 +07:00
Zamil Majdy
b98e37bf23 refactor(copilot): DRY cache-bust helper, fast eviction test, unified JSON parse
Backend:
- Extract _bust_copilot_cache() in creds_manager.py; create/update/delete
  now each call it once instead of repeating the try/except ImportError block
- test_evicts_oldest_when_full: patch _CACHE_MAX_SIZE to 3 to avoid
  allocating 10 000 entries in CI; remove now-unused _CACHE_MAX_SIZE import

Frontend:
- Extract parseJson() helper shared by parseOutput and parseError in
  ConnectIntegrationTool.tsx, eliminating duplicated try/catch logic
2026-03-16 00:01:10 +07:00
Zamil Majdy
fec8924361 fix(copilot): bust token cache on update/delete, tighten except clause
- creds_manager.create/update/delete now all call invalidate_user_provider_cache
  after mutating credentials, so the next bash_exec always picks up the
  current state without waiting for TTL to expire
- Change broad `except Exception` to `except ImportError` in all three methods
  so real bugs inside invalidate_user_provider_cache are not silently swallowed
- delete() reads the provider before deletion so we know which cache key to evict
- Add tests for invalidate_user_provider_cache: removes sentinel/token entry,
  no-op when key absent, only removes the targeted key
2026-03-15 23:57:11 +07:00
Zamil Majdy
712aee7302 fix(copilot): warn on stale OAuth token fallback, document per-process cache
- Log at WARNING (not DEBUG) when OAuth refresh fails and we fall back to
  a potentially stale token, so operators can diagnose repeated auth failures
- Add multi-worker note to module docstring: _token_cache is process-local;
  each replica maintains its own cache (acceptable for current goal, but a
  shared cache would be needed for cross-replica efficiency)
2026-03-15 23:53:10 +07:00
Zamil Majdy
bef292033e fix(copilot): render error state in ConnectIntegrationTool
When part.state is 'output-error', show the error message from the
backend (ErrorResponse.message) in red text below the status line.
Without this, errors from unknown/unsupported providers were silently
discarded, leaving the user without any feedback.
2026-03-15 23:51:09 +07:00
Zamil Majdy
ec6974e3b8 fix(copilot): invalidate null cache on credential creation
When a user connects an integration, IntegrationCredentialsManager.create()
now calls invalidate_user_provider_cache() to remove any stale _NO_TOKEN
sentinel from the TTL cache. Without this, the first retry after connecting
would still return None for up to _NULL_CACHE_TTL (60 s).

The import is done lazily inside create() to avoid a circular import
between integrations.creds_manager and copilot.integration_creds.
2026-03-15 23:50:34 +07:00
Zamil Majdy
2ef5e2fe77 feat(copilot): bounded TTL cache with sentinel for integration creds
- Replace empty-string sentinel with explicit _NO_TOKEN = object() to
  avoid ambiguity with zero-length tokens
- Bound _token_cache to _CACHE_MAX_SIZE=10_000 entries; _cache_set()
  evicts oldest insertion-order entry when full
- Cache "not connected" results with _NULL_CACHE_TTL=60s (vs 300s for
  found tokens) to avoid a DB hit on every E2B bash_exec for users who
  haven't connected yet, while still picking up a new connection quickly
- Add integration_creds_test.py covering all cache paths, sentinel,
  eviction, OAuth2 preferred/fallback, DB exception, and env var injection
2026-03-15 23:46:05 +07:00
Zamil Majdy
0a8c7221ce fix(copilot): address all review findings
- prompting: rename _SDK_TOOL_NOTES → _E2B_TOOL_NOTES; pass it only to
  _get_cloud_sandbox_supplement() via new extra_notes param — local
  (bubblewrap) mode uses --unshare-net so gh CLI cannot reach GitHub
- integration_creds: cache None results with 60 s TTL (_NULL_CACHE_TTL)
  to avoid a DB hit on every E2B bash_exec for users without GitHub creds;
  found tokens still cached for 5 min (_TOKEN_CACHE_TTL)
- connect_integration: add cross-reference comment to PROVIDER_ENV_VARS
- ConnectIntegrationTool: use provider-specific credentialsLabel
  (e.g. "GitHub credentials" instead of "Integration credentials")
2026-03-15 23:35:25 +07:00
Zamil Majdy
840d1de636 refactor(copilot): move token injection to bash_exec + add integration_creds module
- Extract integration token lookup into backend/copilot/integration_creds.py:
  * Generic get_provider_token(user_id, provider) with 5-min TTL cache
  * get_integration_env_vars(user_id) loops over PROVIDER_ENV_VARS registry
  * Adding a new provider only requires a one-line PROVIDER_ENV_VARS entry

- Inject tokens lazily in bash_exec._execute_on_e2b (E2B has internet access;
  bubblewrap uses --unshare-net so gh CLI cannot reach GitHub regardless)

- Remove eager per-turn GH_TOKEN injection from sdk/service.py (wrong layer:
  bubblewrap is network-isolated, E2B injection now done per-command in bash_exec)

- Fix unsafe output.setup_info?.agent_name access in ConnectIntegrationTool

- Add connect_integration_test.py: unknown provider, known provider structure,
  reason in message, session_id propagation, case-insensitive provider slug
2026-03-15 23:29:33 +07:00
Zamil Majdy
ac55ab619b fix(copilot): address coderabbitai nitpicks
- Use `type Props` instead of `interface Props` in ConnectIntegrationTool
- Simplify parseOutput: skip stringify→parse round-trip for objects
- Document why requires_auth=True despite user_id not being used
2026-03-15 23:14:54 +07:00
Zamil Majdy
a8014d1e92 fix(copilot): address sentry — refresh expired OAuth tokens, handle object output in parseOutput
- service.py: use IntegrationCredentialsManager.refresh_if_needed() instead of
  raw IntegrationCredentialsStore so expired GitHub OAuth tokens are refreshed
  before injection; falls back to stale token on refresh failure to avoid
  breaking the turn entirely (lock=False to avoid blocking the turn)
- ConnectIntegrationTool.tsx: parseOutput now handles both string and
  already-parsed object inputs, matching the RunBlock helper pattern
  used elsewhere in the codebase
2026-03-15 23:06:57 +07:00
Zamil Majdy
7de13c7713 fix(copilot): address self-review — GH_TOKEN OAuth preference, unknown provider error, baseline note scope
- service.py: two-pass loop in _get_github_token_for_user() to genuinely
  prefer OAuth2 tokens over API keys; use creds.type discriminator instead
  of isinstance to match codebase style
- connect_integration.py: return ErrorResponse (not SetupRequirementsResponse)
  for unknown providers so the frontend renders a proper error instead of a
  blank broken card; trim "Integration" suffix from agent_name to avoid
  "Connect GitHub Integration" redundancy
- prompting.py: move GitHub CLI / connect_integration guidance from
  _SHARED_TOOL_NOTES (baseline+SDK) to _SDK_TOOL_NOTES (SDK-only) since
  baseline mode has no subprocess, no gh CLI, and no connect_integration tool
- ConnectIntegrationTool.tsx: simplify parseOutput to short-circuit when raw
  is not a string, removing unnecessary JSON.stringify round-trip
2026-03-15 23:00:46 +07:00
Zamil Majdy
9358b525a0 feat(copilot): inject GH_TOKEN and add connect_integration tool for missing GitHub credentials
When the user has connected GitHub, GH_TOKEN is automatically injected
into the Claude Agent SDK subprocess environment so `gh` CLI works
without any manual auth step.

When GitHub is not connected, the copilot can call the new
connect_integration(provider="github") MCP tool, which surfaces the
same credentials setup card used by GitHub blocks — letting the user
connect their account inline without leaving the chat.

- backend: _get_github_token_for_user() fetches the user's GitHub
  credentials (OAuth2 or API key) and injects GH_TOKEN + GITHUB_TOKEN
  into sdk_env before the Claude Agent SDK subprocess starts
- backend: ConnectIntegrationTool MCP tool returns a
  SetupRequirementsResponse for any known provider (github for now)
- backend: prompting.py documents the gh CLI / connect_integration
  flow in _SHARED_TOOL_NOTES so the copilot knows when to call it
- frontend: ConnectIntegrationTool component renders the existing
  SetupRequirementsCard with a tailored retry instruction
- frontend: MessagePartRenderer dispatches tool-connect_integration
  to the new component
2026-03-15 22:55:08 +07:00
26 changed files with 1248 additions and 1088 deletions

View File

@@ -17,17 +17,8 @@ from backend.util.workspace import WorkspaceManager
if TYPE_CHECKING:
from e2b import AsyncSandbox
# Allowed base directory for the Read tool. Public so service.py can use it
# for sweep operations without depending on a private implementation detail.
# Respects CLAUDE_CONFIG_DIR env var, consistent with transcript.py's
# _projects_base() function.
_config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
SDK_PROJECTS_DIR = os.path.realpath(os.path.join(_config_dir, "projects"))
# Compiled UUID pattern for validating conversation directory names.
# Kept as a module-level constant so the security-relevant pattern is easy
# to audit in one place and avoids recompilation on every call.
_UUID_RE = re.compile(r"^[0-9a-f]{8}(?:-[0-9a-f]{4}){3}-[0-9a-f]{12}$", re.IGNORECASE)
# Allowed base directory for the Read tool.
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# Encoded project-directory name for the current session (e.g.
# "-private-tmp-copilot-<uuid>"). Set by set_execution_context() so path
@@ -44,20 +35,11 @@ _current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
def encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does.
The Claude CLI encodes the absolute cwd as a directory name by replacing
every non-alphanumeric character with ``-``. For example
``/tmp/copilot-abc`` becomes ``-tmp-copilot-abc``.
"""
def _encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does."""
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
# Keep the private alias for internal callers (backwards compat).
_encode_cwd_for_cli = encode_cwd_for_cli
def set_execution_context(
user_id: str | None,
session: ChatSession,
@@ -118,9 +100,7 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
Allowed:
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
- Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``.
The SDK nests tool-results under a conversation UUID directory;
the UUID segment is validated with ``_UUID_RE``.
- Files under ``~/.claude/projects/<encoded-cwd>/tool-results/`` (SDK tool-results)
"""
if not path:
return False
@@ -139,22 +119,10 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
encoded = _current_project_dir.get("")
if encoded:
project_dir = os.path.realpath(os.path.join(SDK_PROJECTS_DIR, encoded))
# Defence-in-depth: ensure project_dir didn't escape the base.
if not project_dir.startswith(SDK_PROJECTS_DIR + os.sep):
return False
# Only allow: <encoded-cwd>/<uuid>/tool-results/<file>
# The SDK always creates a conversation UUID directory between
# the project dir and tool-results/.
if resolved.startswith(project_dir + os.sep):
relative = resolved[len(project_dir) + 1 :]
parts = relative.split(os.sep)
# Require exactly: [<uuid>, "tool-results", <file>, ...]
if (
len(parts) >= 3
and _UUID_RE.match(parts[0])
and parts[1] == "tool-results"
):
return True
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
if resolved == tool_results_dir or resolved.startswith(
tool_results_dir + os.sep
):
return True
return False

View File

@@ -9,7 +9,7 @@ from unittest.mock import MagicMock
import pytest
from backend.copilot.context import (
SDK_PROJECTS_DIR,
_SDK_PROJECTS_DIR,
_current_project_dir,
get_current_sandbox,
get_execution_context,
@@ -104,13 +104,11 @@ def test_is_allowed_local_path_no_sdk_cwd_no_project_dir():
assert not is_allowed_local_path("/tmp/some-file.txt", sdk_cwd=None)
def test_is_allowed_local_path_tool_results_with_uuid():
"""Files under <encoded-cwd>/<uuid>/tool-results/ are allowed."""
def test_is_allowed_local_path_tool_results_dir():
"""Files under the tool-results directory for the current project are allowed."""
encoded = "test-encoded-dir"
conv_uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
path = os.path.join(
SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-results", "output.txt"
)
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
path = os.path.join(tool_results_dir, "output.txt")
_current_project_dir.set(encoded)
try:
@@ -119,22 +117,10 @@ def test_is_allowed_local_path_tool_results_with_uuid():
_current_project_dir.set("")
def test_is_allowed_local_path_tool_results_without_uuid_rejected():
"""Direct <encoded-cwd>/tool-results/ (no UUID) is rejected."""
encoded = "test-encoded-dir"
path = os.path.join(SDK_PROJECTS_DIR, encoded, "tool-results", "output.txt")
_current_project_dir.set(encoded)
try:
assert not is_allowed_local_path(path, sdk_cwd=None)
finally:
_current_project_dir.set("")
def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
"""A path adjacent to tool-results/ but not inside it is rejected."""
encoded = "test-encoded-dir"
sibling_path = os.path.join(SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
sibling_path = os.path.join(_SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
_current_project_dir.set(encoded)
try:
@@ -143,21 +129,6 @@ def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
_current_project_dir.set("")
def test_is_allowed_local_path_valid_uuid_wrong_segment_name_rejected():
"""A valid UUID dir but non-'tool-results' second segment is rejected."""
encoded = "test-encoded-dir"
uuid_str = "12345678-1234-5678-9abc-def012345678"
path = os.path.join(
SDK_PROJECTS_DIR, encoded, uuid_str, "not-tool-results", "output.txt"
)
_current_project_dir.set(encoded)
try:
assert not is_allowed_local_path(path, sdk_cwd=None)
finally:
_current_project_dir.set("")
# ---------------------------------------------------------------------------
# resolve_sandbox_path
# ---------------------------------------------------------------------------

View File

@@ -0,0 +1,162 @@
"""Integration credential lookup with per-process TTL cache.
Provides token retrieval for connected integrations so that copilot tools
(e.g. bash_exec) can inject auth tokens into the execution environment without
hitting the database on every command.
Cache semantics (handled automatically by TTLCache):
- Token found → cached for _TOKEN_CACHE_TTL (5 min). Avoids repeated DB hits
for users who have credentials and are running many bash commands.
- No credentials found → cached for _NULL_CACHE_TTL (60 s). Avoids a DB hit
on every E2B command for users who haven't connected an account yet, while
still picking up a newly-connected account within one minute.
Both caches are bounded to _CACHE_MAX_SIZE entries; cachetools evicts the
least-recently-used entry when the limit is reached.
Multi-worker note: both caches are in-process only. Each worker/replica
maintains its own independent cache, so a credential fetch may be duplicated
across processes. This is acceptable for the current goal (reduce DB hits per
session per-process), but if cache efficiency across replicas becomes important
a shared cache (e.g. Redis) should be used instead.
"""
import logging
from typing import cast
from cachetools import TTLCache
from backend.data.model import APIKeyCredentials, OAuth2Credentials
from backend.integrations.creds_manager import (
IntegrationCredentialsManager,
register_creds_changed_hook,
)
logger = logging.getLogger(__name__)
# Maps provider slug → env var names to inject when the provider is connected.
# Add new providers here when adding integration support.
# NOTE: keep in sync with connect_integration._PROVIDER_INFO — both registries
# must be updated when adding a new provider.
PROVIDER_ENV_VARS: dict[str, list[str]] = {
"github": ["GH_TOKEN", "GITHUB_TOKEN"],
}
_TOKEN_CACHE_TTL = 300.0 # seconds — for found tokens
_NULL_CACHE_TTL = 60.0 # seconds — for "not connected" results
_CACHE_MAX_SIZE = 10_000
# (user_id, provider) → token string. TTLCache handles expiry + eviction.
# Thread-safety note: TTLCache is NOT thread-safe, but that is acceptable here
# because all callers (get_provider_token, invalidate_user_provider_cache) run
# exclusively on the asyncio event loop. There are no await points between a
# cache read and its corresponding write within any function, so no concurrent
# coroutine can interleave. If ThreadPoolExecutor workers are ever added to
# this path, a threading.RLock should be wrapped around these caches.
_token_cache: TTLCache[tuple[str, str], str] = TTLCache(
maxsize=_CACHE_MAX_SIZE, ttl=_TOKEN_CACHE_TTL
)
# Separate cache for "no credentials" results with a shorter TTL.
_null_cache: TTLCache[tuple[str, str], bool] = TTLCache(
maxsize=_CACHE_MAX_SIZE, ttl=_NULL_CACHE_TTL
)
def invalidate_user_provider_cache(user_id: str, provider: str) -> None:
"""Remove the cached entry for *user_id*/*provider* from both caches.
Call this after storing new credentials so that the next
``get_provider_token()`` call performs a fresh DB lookup instead of
serving a stale TTL-cached result.
"""
key = (user_id, provider)
_token_cache.pop(key, None)
_null_cache.pop(key, None)
# Register this module's cache-bust function with the credentials manager so
# that any create/update/delete operation immediately evicts stale cache
# entries. This avoids a lazy import inside creds_manager and eliminates the
# circular-import risk.
register_creds_changed_hook(invalidate_user_provider_cache)
# Module-level singleton to avoid re-instantiating IntegrationCredentialsManager
# on every cache-miss call to get_provider_token().
_manager = IntegrationCredentialsManager()
async def get_provider_token(user_id: str, provider: str) -> str | None:
"""Return the user's access token for *provider*, or ``None`` if not connected.
OAuth2 tokens are preferred (refreshed if needed); API keys are the fallback.
Found tokens are cached for _TOKEN_CACHE_TTL (5 min). "Not connected" results
are cached for _NULL_CACHE_TTL (60 s) to avoid a DB hit on every bash_exec
command for users who haven't connected yet, while still picking up a
newly-connected account within one minute.
"""
cache_key = (user_id, provider)
if cache_key in _null_cache:
return None
if cached := _token_cache.get(cache_key):
return cached
manager = _manager
try:
creds_list = await manager.store.get_creds_by_provider(user_id, provider)
except Exception:
logger.debug("Failed to fetch %s credentials for user %s", provider, user_id)
return None
# Pass 1: prefer OAuth2 (carry scope info, refreshable via token endpoint).
# Sort so broader-scoped tokens come first: a token with "repo" scope covers
# full git access, while a public-data-only token lacks push/pull permission.
# lock=False — background injection; not worth a distributed lock acquisition.
oauth2_creds = sorted(
[c for c in creds_list if c.type == "oauth2"],
key=lambda c: 0 if "repo" in (cast(OAuth2Credentials, c).scopes or []) else 1,
)
for creds in oauth2_creds:
if creds.type == "oauth2":
try:
fresh = await manager.refresh_if_needed(
user_id, cast(OAuth2Credentials, creds), lock=False
)
token = fresh.access_token.get_secret_value()
except Exception:
logger.warning(
"Failed to refresh %s OAuth token for user %s; "
"falling back to potentially stale token",
provider,
user_id,
)
token = cast(OAuth2Credentials, creds).access_token.get_secret_value()
_token_cache[cache_key] = token
return token
# Pass 2: fall back to API key (no expiry, no refresh needed).
for creds in creds_list:
if creds.type == "api_key":
token = cast(APIKeyCredentials, creds).api_key.get_secret_value()
_token_cache[cache_key] = token
return token
# No credentials found — cache to avoid repeated DB hits.
_null_cache[cache_key] = True
return None
async def get_integration_env_vars(user_id: str) -> dict[str, str]:
"""Return env vars for all providers the user has connected.
Iterates :data:`PROVIDER_ENV_VARS`, fetches each token, and builds a flat
``{env_var: token}`` dict ready to pass to a subprocess or E2B sandbox.
Only providers with a stored credential contribute entries.
"""
env: dict[str, str] = {}
for provider, var_names in PROVIDER_ENV_VARS.items():
token = await get_provider_token(user_id, provider)
if token:
for var in var_names:
env[var] = token
return env

View File

@@ -0,0 +1,193 @@
"""Tests for integration_creds — TTL cache and token lookup paths."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
from backend.copilot.integration_creds import (
_NULL_CACHE_TTL,
_TOKEN_CACHE_TTL,
PROVIDER_ENV_VARS,
_null_cache,
_token_cache,
get_integration_env_vars,
get_provider_token,
invalidate_user_provider_cache,
)
from backend.data.model import APIKeyCredentials, OAuth2Credentials
_USER = "user-integration-creds-test"
_PROVIDER = "github"
def _make_api_key_creds(key: str = "test-api-key") -> APIKeyCredentials:
return APIKeyCredentials(
id="creds-api-key",
provider=_PROVIDER,
api_key=SecretStr(key),
title="Test API Key",
expires_at=None,
)
def _make_oauth2_creds(token: str = "test-oauth-token") -> OAuth2Credentials:
return OAuth2Credentials(
id="creds-oauth2",
provider=_PROVIDER,
title="Test OAuth",
access_token=SecretStr(token),
refresh_token=SecretStr("test-refresh"),
access_token_expires_at=None,
refresh_token_expires_at=None,
scopes=[],
)
@pytest.fixture(autouse=True)
def clear_caches():
"""Ensure clean caches before and after every test."""
_token_cache.clear()
_null_cache.clear()
yield
_token_cache.clear()
_null_cache.clear()
class TestInvalidateUserProviderCache:
def test_removes_token_entry(self):
key = (_USER, _PROVIDER)
_token_cache[key] = "tok"
invalidate_user_provider_cache(_USER, _PROVIDER)
assert key not in _token_cache
def test_removes_null_entry(self):
key = (_USER, _PROVIDER)
_null_cache[key] = True
invalidate_user_provider_cache(_USER, _PROVIDER)
assert key not in _null_cache
def test_noop_when_key_not_cached(self):
# Should not raise even when there is no cache entry.
invalidate_user_provider_cache("no-such-user", _PROVIDER)
def test_only_removes_targeted_key(self):
other_key = ("other-user", _PROVIDER)
_token_cache[other_key] = "other-tok"
invalidate_user_provider_cache(_USER, _PROVIDER)
assert other_key in _token_cache
class TestGetProviderToken:
@pytest.mark.asyncio(loop_scope="session")
async def test_returns_cached_token_without_db_hit(self):
_token_cache[(_USER, _PROVIDER)] = "cached-tok"
mock_manager = MagicMock()
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result == "cached-tok"
mock_manager.store.get_creds_by_provider.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
async def test_returns_none_for_null_cached_provider(self):
_null_cache[(_USER, _PROVIDER)] = True
mock_manager = MagicMock()
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result is None
mock_manager.store.get_creds_by_provider.assert_not_called()
@pytest.mark.asyncio(loop_scope="session")
async def test_api_key_creds_returned_and_cached(self):
api_creds = _make_api_key_creds("my-api-key")
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[api_creds])
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result == "my-api-key"
assert _token_cache.get((_USER, _PROVIDER)) == "my-api-key"
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth2_preferred_over_api_key(self):
oauth_creds = _make_oauth2_creds("oauth-tok")
api_creds = _make_api_key_creds("api-tok")
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(
return_value=[api_creds, oauth_creds]
)
mock_manager.refresh_if_needed = AsyncMock(return_value=oauth_creds)
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result == "oauth-tok"
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth2_refresh_failure_falls_back_to_stale_token(self):
oauth_creds = _make_oauth2_creds("stale-oauth-tok")
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[oauth_creds])
mock_manager.refresh_if_needed = AsyncMock(side_effect=RuntimeError("network"))
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result == "stale-oauth-tok"
@pytest.mark.asyncio(loop_scope="session")
async def test_no_credentials_caches_null_entry(self):
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(return_value=[])
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result is None
assert _null_cache.get((_USER, _PROVIDER)) is True
@pytest.mark.asyncio(loop_scope="session")
async def test_db_exception_returns_none_without_caching(self):
mock_manager = MagicMock()
mock_manager.store.get_creds_by_provider = AsyncMock(
side_effect=RuntimeError("db down")
)
with patch("backend.copilot.integration_creds._manager", mock_manager):
result = await get_provider_token(_USER, _PROVIDER)
assert result is None
# DB errors are not cached — next call will retry
assert (_USER, _PROVIDER) not in _token_cache
assert (_USER, _PROVIDER) not in _null_cache
@pytest.mark.asyncio(loop_scope="session")
async def test_null_cache_has_shorter_ttl_than_token_cache(self):
"""Verify the TTL constants are set correctly for each cache."""
assert _null_cache.ttl == _NULL_CACHE_TTL
assert _token_cache.ttl == _TOKEN_CACHE_TTL
assert _NULL_CACHE_TTL < _TOKEN_CACHE_TTL
class TestGetIntegrationEnvVars:
@pytest.mark.asyncio(loop_scope="session")
async def test_injects_all_env_vars_for_provider(self):
_token_cache[(_USER, "github")] = "gh-tok"
result = await get_integration_env_vars(_USER)
for var in PROVIDER_ENV_VARS["github"]:
assert result[var] == "gh-tok"
@pytest.mark.asyncio(loop_scope="session")
async def test_empty_dict_when_no_credentials(self):
_null_cache[(_USER, "github")] = True
result = await get_integration_env_vars(_USER)
assert result == {}

View File

@@ -95,6 +95,25 @@ Example — committing an image file to GitHub:
All tasks must run in the foreground.
"""
# E2B-only notes — E2B has full internet access so gh CLI works there.
# Not shown in local (bubblewrap) mode: --unshare-net blocks all network.
_E2B_TOOL_NOTES = """
### GitHub CLI (`gh`) and git
- If the user has connected their GitHub account, both `gh` and `git` are
pre-authenticated — use them directly without any manual login step.
`git` HTTPS operations (clone, push, pull) work automatically.
- If the token changes mid-session (e.g. user reconnects with a new token),
run `gh auth setup-git` to re-register the credential helper.
- If `gh` or `git` fails with an authentication error (e.g. "authentication
required", "could not read Username", or exit code 128), call
`connect_integration(provider="github")` to surface the GitHub credentials
setup card so the user can connect their account. Once connected, retry
the operation.
- For operations that need broader access (e.g. private org repos, GitHub
Actions), pass the required scopes: e.g.
`connect_integration(provider="github", scopes=["repo", "read:org"])`.
"""
# Environment-specific supplement templates
def _build_storage_supplement(
@@ -105,6 +124,7 @@ def _build_storage_supplement(
storage_system_1_persistence: list[str],
file_move_name_1_to_2: str,
file_move_name_2_to_1: str,
extra_notes: str = "",
) -> str:
"""Build storage/filesystem supplement for a specific environment.
@@ -119,6 +139,7 @@ def _build_storage_supplement(
storage_system_1_persistence: List of persistence behavior descriptions
file_move_name_1_to_2: Direction label for primary→persistent
file_move_name_2_to_1: Direction label for persistent→primary
extra_notes: Environment-specific notes appended after shared notes
"""
# Format lists as bullet points with proper indentation
characteristics = "\n".join(f" - {c}" for c in storage_system_1_characteristics)
@@ -152,19 +173,16 @@ def _build_storage_supplement(
### File persistence
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
### SDK tool-result files
When tool outputs are large, the SDK truncates them and saves the full output to
a local file under `~/.claude/projects/.../tool-results/`. To read these files,
always use `read_file` or `Read` (NOT `read_workspace_file`).
`read_workspace_file` reads from cloud workspace storage, where SDK
tool-results are NOT stored.
{_SHARED_TOOL_NOTES}"""
{_SHARED_TOOL_NOTES}{extra_notes}"""
# Pre-built supplements for common environments
def _get_local_storage_supplement(cwd: str) -> str:
"""Local ephemeral storage (files lost between turns)."""
"""Local ephemeral storage (files lost between turns).
Network is isolated (bubblewrap --unshare-net), so internet-dependent CLIs
like gh will not work — no integration env-var notes are included.
"""
return _build_storage_supplement(
working_dir=cwd,
sandbox_type="in a network-isolated sandbox",
@@ -182,7 +200,11 @@ def _get_local_storage_supplement(cwd: str) -> str:
def _get_cloud_sandbox_supplement() -> str:
"""Cloud persistent sandbox (files survive across turns in session)."""
"""Cloud persistent sandbox (files survive across turns in session).
E2B has full internet access, so integration tokens (GH_TOKEN etc.) are
injected per command in bash_exec — include the CLI guidance notes.
"""
return _build_storage_supplement(
working_dir="/home/user",
sandbox_type="in a cloud sandbox with full internet access",
@@ -197,6 +219,7 @@ def _get_cloud_sandbox_supplement() -> str:
],
file_move_name_1_to_2="Sandbox → Persistent",
file_move_name_2_to_1="Persistent → Sandbox",
extra_notes=_E2B_TOOL_NOTES,
)

View File

@@ -26,41 +26,6 @@ from backend.copilot.context import (
logger = logging.getLogger(__name__)
async def _check_sandbox_symlink_escape(
sandbox: Any,
parent: str,
) -> str | None:
"""Resolve the canonical parent path inside the sandbox to detect symlink escapes.
``normpath`` (used by ``resolve_sandbox_path``) only normalises the string;
``readlink -f`` follows actual symlinks on the sandbox filesystem.
Returns the canonical parent path, or ``None`` if the path escapes
``E2B_WORKDIR``.
Note: There is an inherent TOCTOU window between this check and the
subsequent ``sandbox.files.write()``. A symlink could theoretically be
replaced between the two operations. This is acceptable in the E2B
sandbox model since the sandbox is single-user and ephemeral.
"""
canonical_res = await sandbox.commands.run(
f"readlink -f {shlex.quote(parent or E2B_WORKDIR)}",
cwd=E2B_WORKDIR,
timeout=5,
)
canonical_parent = (canonical_res.stdout or "").strip()
if (
canonical_res.exit_code != 0
or not canonical_parent
or (
canonical_parent != E2B_WORKDIR
and not canonical_parent.startswith(E2B_WORKDIR + "/")
)
):
return None
return canonical_parent
def _get_sandbox():
return get_current_sandbox()
@@ -141,10 +106,6 @@ async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
parent = os.path.dirname(remote)
if parent and parent != E2B_WORKDIR:
await sandbox.files.make_dir(parent)
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
if canonical_parent is None:
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
remote = os.path.join(canonical_parent, os.path.basename(remote))
await sandbox.files.write(remote, content)
except Exception as exc:
return _mcp(f"Failed to write {remote}: {exc}", error=True)
@@ -169,12 +130,6 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
return result
sandbox, remote = result
parent = os.path.dirname(remote)
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
if canonical_parent is None:
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
remote = os.path.join(canonical_parent, os.path.basename(remote))
try:
raw: bytes = await sandbox.files.read(remote, format="bytes")
content = raw.decode("utf-8", errors="replace")

View File

@@ -4,19 +4,15 @@ Pure unit tests with no external dependencies (no E2B, no sandbox).
"""
import os
import shutil
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from backend.copilot.context import E2B_WORKDIR, SDK_PROJECTS_DIR, _current_project_dir
from backend.copilot.context import _current_project_dir
from .e2b_file_tools import _read_local, resolve_sandbox_path
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
from .e2b_file_tools import (
_check_sandbox_symlink_escape,
_read_local,
resolve_sandbox_path,
)
# ---------------------------------------------------------------------------
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
@@ -25,48 +21,46 @@ from .e2b_file_tools import (
class TestResolveSandboxPath:
def test_relative_path_resolved(self):
assert resolve_sandbox_path("src/main.py") == f"{E2B_WORKDIR}/src/main.py"
assert resolve_sandbox_path("src/main.py") == "/home/user/src/main.py"
def test_absolute_within_sandbox(self):
assert (
resolve_sandbox_path(f"{E2B_WORKDIR}/file.txt") == f"{E2B_WORKDIR}/file.txt"
)
assert resolve_sandbox_path("/home/user/file.txt") == "/home/user/file.txt"
def test_workdir_itself(self):
assert resolve_sandbox_path(E2B_WORKDIR) == E2B_WORKDIR
assert resolve_sandbox_path("/home/user") == "/home/user"
def test_relative_dotslash(self):
assert resolve_sandbox_path("./README.md") == f"{E2B_WORKDIR}/README.md"
assert resolve_sandbox_path("./README.md") == "/home/user/README.md"
def test_traversal_blocked(self):
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
with pytest.raises(ValueError, match="must be within /home/user"):
resolve_sandbox_path("../../etc/passwd")
def test_absolute_traversal_blocked(self):
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
resolve_sandbox_path(f"{E2B_WORKDIR}/../../etc/passwd")
with pytest.raises(ValueError, match="must be within /home/user"):
resolve_sandbox_path("/home/user/../../etc/passwd")
def test_absolute_outside_sandbox_blocked(self):
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
with pytest.raises(ValueError, match="must be within /home/user"):
resolve_sandbox_path("/etc/passwd")
def test_root_blocked(self):
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
with pytest.raises(ValueError, match="must be within /home/user"):
resolve_sandbox_path("/")
def test_home_other_user_blocked(self):
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
with pytest.raises(ValueError, match="must be within /home/user"):
resolve_sandbox_path("/home/other/file.txt")
def test_deep_nested_allowed(self):
assert resolve_sandbox_path("a/b/c/d/e.txt") == f"{E2B_WORKDIR}/a/b/c/d/e.txt"
assert resolve_sandbox_path("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
def test_trailing_slash_normalised(self):
assert resolve_sandbox_path("src/") == f"{E2B_WORKDIR}/src"
assert resolve_sandbox_path("src/") == "/home/user/src"
def test_double_dots_within_sandbox_ok(self):
"""Path that resolves back within E2B_WORKDIR is allowed."""
assert resolve_sandbox_path("a/b/../c.txt") == f"{E2B_WORKDIR}/a/c.txt"
"""Path that resolves back within /home/user is allowed."""
assert resolve_sandbox_path("a/b/../c.txt") == "/home/user/a/c.txt"
# ---------------------------------------------------------------------------
@@ -79,13 +73,9 @@ class TestResolveSandboxPath:
class TestReadLocal:
_CONV_UUID = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
def _make_tool_results_file(self, encoded: str, filename: str, content: str) -> str:
"""Create a tool-results file under <encoded>/<uuid>/tool-results/."""
tool_results_dir = os.path.join(
SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-results"
)
"""Create a tool-results file and return its path."""
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
os.makedirs(tool_results_dir, exist_ok=True)
filepath = os.path.join(tool_results_dir, filename)
with open(filepath, "w") as f:
@@ -117,9 +107,7 @@ class TestReadLocal:
def test_read_nonexistent_tool_results(self):
"""A tool-results path that doesn't exist returns FileNotFoundError."""
encoded = "-tmp-copilot-e2b-test-nofile"
tool_results_dir = os.path.join(
SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-results"
)
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
os.makedirs(tool_results_dir, exist_ok=True)
filepath = os.path.join(tool_results_dir, "nonexistent.txt")
token = _current_project_dir.set(encoded)
@@ -129,7 +117,7 @@ class TestReadLocal:
assert "not found" in result["content"][0]["text"].lower()
finally:
_current_project_dir.reset(token)
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
os.rmdir(tool_results_dir)
def test_read_traversal_path_blocked(self):
"""A traversal attempt that escapes allowed directories is blocked."""
@@ -164,66 +152,3 @@ class TestReadLocal:
"""Without _current_project_dir set, all paths are blocked."""
result = _read_local("/tmp/anything.txt", offset=0, limit=10)
assert result["isError"] is True
# ---------------------------------------------------------------------------
# _check_sandbox_symlink_escape — symlink escape detection
# ---------------------------------------------------------------------------
def _make_sandbox(stdout: str, exit_code: int = 0) -> SimpleNamespace:
"""Build a minimal sandbox mock whose commands.run returns a fixed result."""
run_result = SimpleNamespace(stdout=stdout, exit_code=exit_code)
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
return SimpleNamespace(commands=commands)
class TestCheckSandboxSymlinkEscape:
@pytest.mark.asyncio
async def test_canonical_path_within_workdir_returns_path(self):
"""When readlink -f resolves to a path inside E2B_WORKDIR, returns it."""
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}/src\n", exit_code=0)
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
assert result == f"{E2B_WORKDIR}/src"
@pytest.mark.asyncio
async def test_workdir_itself_returns_workdir(self):
"""When readlink -f resolves to E2B_WORKDIR exactly, returns E2B_WORKDIR."""
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}\n", exit_code=0)
result = await _check_sandbox_symlink_escape(sandbox, E2B_WORKDIR)
assert result == E2B_WORKDIR
@pytest.mark.asyncio
async def test_symlink_escape_returns_none(self):
"""When readlink -f resolves outside E2B_WORKDIR (symlink escape), returns None."""
sandbox = _make_sandbox(stdout="/etc\n", exit_code=0)
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/evil")
assert result is None
@pytest.mark.asyncio
async def test_nonzero_exit_code_returns_none(self):
"""A non-zero exit code from readlink -f returns None."""
sandbox = _make_sandbox(stdout="", exit_code=1)
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
assert result is None
@pytest.mark.asyncio
async def test_empty_stdout_returns_none(self):
"""Empty stdout from readlink (e.g. path doesn't exist yet) returns None."""
sandbox = _make_sandbox(stdout="", exit_code=0)
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
assert result is None
@pytest.mark.asyncio
async def test_prefix_collision_returns_none(self):
"""A path prefixed with E2B_WORKDIR but not within it is rejected."""
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}-evil\n", exit_code=0)
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}-evil")
assert result is None
@pytest.mark.asyncio
async def test_deeply_nested_path_within_workdir(self):
"""Deep nested paths inside E2B_WORKDIR are allowed."""
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}/a/b/c/d\n", exit_code=0)
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/a/b/c/d")
assert result == f"{E2B_WORKDIR}/a/b/c/d"

View File

@@ -42,7 +42,7 @@ def _validate_workspace_path(
Delegates to :func:`is_allowed_local_path` which permits:
- The SDK working directory (``/tmp/copilot-<session>/``)
- The current session's tool-results directory
(``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/``)
(``~/.claude/projects/<encoded-cwd>/tool-results/``)
"""
path = tool_input.get("file_path") or tool_input.get("path") or ""
if not path:
@@ -302,11 +302,7 @@ def create_security_hooks(
"""
_ = context, tool_use_id
trigger = input_data.get("trigger", "auto")
# Sanitize untrusted input: strip control chars for logging AND
# for the value passed downstream. read_compacted_entries()
# validates against _projects_base() as defence-in-depth, but
# sanitizing here prevents log injection and rejects obviously
# malformed paths early.
# Sanitize untrusted input before logging to prevent log injection
transcript_path = (
str(input_data.get("transcript_path", ""))
.replace("\n", "")

View File

@@ -122,7 +122,7 @@ def test_read_no_cwd_denies_absolute():
def test_read_tool_results_allowed():
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/a1b2c3d4-e5f6-7890-abcd-ef1234567890/tool-results/12345.txt"
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
# is_allowed_local_path requires the session's encoded cwd to be set
token = _current_project_dir.set("-tmp-copilot-abc123")
try:

View File

@@ -10,7 +10,6 @@ import re
import shutil
import subprocess
import sys
import time
import uuid
from collections.abc import AsyncGenerator
from typing import Any, cast
@@ -39,7 +38,6 @@ from backend.util.settings import Settings
from ..config import ChatConfig
from ..constants import COPILOT_ERROR_PREFIX, COPILOT_SYSTEM_PREFIX
from ..context import encode_cwd_for_cli
from ..model import (
ChatMessage,
ChatSession,
@@ -77,7 +75,7 @@ from .tool_adapter import (
wait_for_stash,
)
from .transcript import (
cleanup_stale_project_dirs,
cleanup_cli_project_dir,
download_transcript,
read_compacted_entries,
upload_transcript,
@@ -145,9 +143,6 @@ _background_tasks: set[asyncio.Task[Any]] = set()
_SDK_CWD_PREFIX = WORKSPACE_PREFIX
_last_sweep_time: float = 0.0
_SWEEP_INTERVAL_SECONDS = 300 # 5 minutes
# Heartbeat interval — keep SSE alive through proxies/LBs during tool execution.
# IMPORTANT: Must be less than frontend timeout (12s in useCopilotPage.ts)
_HEARTBEAT_INTERVAL = 10.0 # seconds
@@ -286,34 +281,31 @@ def _make_sdk_cwd(session_id: str) -> str:
return cwd
async def _cleanup_sdk_tool_results(cwd: str) -> None:
def _cleanup_sdk_tool_results(cwd: str) -> None:
"""Remove SDK session artifacts for a specific working directory.
Cleans up the ephemeral working directory ``/tmp/copilot-<session>/``.
Also sweeps stale CLI project directories (older than 12 h) to prevent
unbounded disk growth. The sweep is best-effort, rate-limited to once
every 5 minutes, and capped at 50 directories per sweep.
Cleans up:
- ``~/.claude/projects/<encoded-cwd>/`` — CLI session transcripts and
tool-result files. Each SDK turn uses a unique cwd, so this directory
is safe to remove entirely.
- ``/tmp/copilot-<session>/`` — the ephemeral working directory.
Security: *cwd* MUST be created by ``_make_sdk_cwd()`` which sanitizes
the session_id.
"""
normalized = os.path.normpath(cwd)
if not normalized.startswith(_SDK_CWD_PREFIX):
logger.warning("[SDK] Rejecting cleanup for path outside workspace: %s", cwd)
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
return
await asyncio.to_thread(shutil.rmtree, normalized, True)
# Clean the CLI's project directory (transcripts + tool-results).
cleanup_cli_project_dir(cwd)
# Best-effort sweep of old project dirs to prevent disk leak.
# Pass the encoded cwd so only this session's project directory is swept,
# which is safe in multi-tenant environments.
global _last_sweep_time
now = time.time()
if now - _last_sweep_time >= _SWEEP_INTERVAL_SECONDS:
_last_sweep_time = now
encoded = encode_cwd_for_cli(normalized)
await asyncio.to_thread(cleanup_stale_project_dirs, encoded)
# Clean up the temp cwd directory itself.
try:
shutil.rmtree(normalized, ignore_errors=True)
except OSError:
pass
def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
@@ -777,7 +769,7 @@ async def stream_chat_completion_sdk(
)
return None
try:
return await get_or_create_sandbox(
sandbox = await get_or_create_sandbox(
session_id,
api_key=e2b_api_key,
template=config.e2b_sandbox_template,
@@ -791,7 +783,9 @@ async def stream_chat_completion_sdk(
e2b_err,
exc_info=True,
)
return None
return None
return sandbox
async def _fetch_transcript():
"""Download transcript for --resume if applicable."""
@@ -805,7 +799,7 @@ async def stream_chat_completion_sdk(
)
except Exception as transcript_err:
logger.warning(
"%s Transcript download failed, continuing without --resume: %s",
"%s Transcript download failed, continuing without " "--resume: %s",
log_prefix,
transcript_err,
)
@@ -828,7 +822,7 @@ async def stream_chat_completion_sdk(
is_valid = validate_transcript(dl.content)
dl_lines = dl.content.strip().split("\n") if dl.content else []
logger.info(
"%s Downloaded transcript: %dB, %d lines, msg_count=%d, valid=%s",
"%s Downloaded transcript: %dB, %d lines, " "msg_count=%d, valid=%s",
log_prefix,
len(dl.content),
len(dl_lines),
@@ -1062,7 +1056,8 @@ async def stream_chat_completion_sdk(
break
logger.info(
"%s Received: %s %s (unresolved=%d, current=%d, resolved=%d)",
"%s Received: %s %s "
"(unresolved=%d, current=%d, resolved=%d)",
log_prefix,
type(sdk_msg).__name__,
getattr(sdk_msg, "subtype", ""),
@@ -1107,14 +1102,7 @@ async def stream_chat_completion_sdk(
and isinstance(sdk_msg, (AssistantMessage, ResultMessage))
and not is_parallel_continuation
):
# 2.0 s timeout: the original 0.5 s caused frequent
# timeouts under load (parallel tool calls, large
# outputs). 2.0 s gives margin while still failing
# fast when the hook genuinely will not fire.
if await wait_for_stash(timeout=2.0):
# Yield once so any callbacks scheduled by the
# stash signal can propagate before we process
# the next SDK message.
if await wait_for_stash(timeout=0.5):
await asyncio.sleep(0)
else:
logger.warning(
@@ -1500,14 +1488,11 @@ async def stream_chat_completion_sdk(
exc_info=True,
)
try:
if sdk_cwd:
await _cleanup_sdk_tool_results(sdk_cwd)
except Exception:
logger.warning("%s SDK cleanup failed", log_prefix, exc_info=True)
finally:
# Release stream lock to allow new streams for this session
await lock.release()
if sdk_cwd:
_cleanup_sdk_tool_results(sdk_cwd)
# Release stream lock to allow new streams for this session
await lock.release()
async def _update_title_async(

View File

@@ -288,90 +288,3 @@ class TestPromptSupplement:
# Count how many times this tool appears as a bullet point
count = docs.count(f"- **`{tool_name}`**")
assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)"
# ---------------------------------------------------------------------------
# _cleanup_sdk_tool_results — orchestration + rate-limiting
# ---------------------------------------------------------------------------
class TestCleanupSdkToolResults:
"""Tests for _cleanup_sdk_tool_results orchestration and sweep rate-limiting."""
# All valid cwds must start with /tmp/copilot- (the _SDK_CWD_PREFIX).
_CWD_PREFIX = "/tmp/copilot-"
@pytest.mark.asyncio
async def test_removes_cwd_directory(self):
"""Cleanup removes the session working directory."""
from .service import _cleanup_sdk_tool_results
cwd = "/tmp/copilot-test-cleanup-remove"
os.makedirs(cwd, exist_ok=True)
with patch("backend.copilot.sdk.service.cleanup_stale_project_dirs"):
import backend.copilot.sdk.service as svc_mod
svc_mod._last_sweep_time = 0.0
await _cleanup_sdk_tool_results(cwd)
assert not os.path.exists(cwd)
@pytest.mark.asyncio
async def test_sweep_runs_when_interval_elapsed(self):
"""cleanup_stale_project_dirs is called when 5-minute interval has elapsed."""
import backend.copilot.sdk.service as svc_mod
from .service import _cleanup_sdk_tool_results
cwd = "/tmp/copilot-test-sweep-elapsed"
os.makedirs(cwd, exist_ok=True)
with patch(
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
) as mock_sweep:
# Set last sweep to a time far in the past
svc_mod._last_sweep_time = 0.0
await _cleanup_sdk_tool_results(cwd)
mock_sweep.assert_called_once()
@pytest.mark.asyncio
async def test_sweep_skipped_within_interval(self):
"""cleanup_stale_project_dirs is NOT called when within 5-minute interval."""
import time
import backend.copilot.sdk.service as svc_mod
from .service import _cleanup_sdk_tool_results
cwd = "/tmp/copilot-test-sweep-ratelimit"
os.makedirs(cwd, exist_ok=True)
with patch(
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
) as mock_sweep:
# Set last sweep to now — interval not elapsed
svc_mod._last_sweep_time = time.time()
await _cleanup_sdk_tool_results(cwd)
mock_sweep.assert_not_called()
@pytest.mark.asyncio
async def test_rejects_path_outside_prefix(self, tmp_path):
"""Cleanup rejects a cwd that does not start with the expected prefix."""
from .service import _cleanup_sdk_tool_results
evil_cwd = str(tmp_path / "evil-path")
os.makedirs(evil_cwd, exist_ok=True)
with patch(
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
) as mock_sweep:
await _cleanup_sdk_tool_results(evil_cwd)
# Directory should NOT have been removed (rejected early)
assert os.path.exists(evil_cwd)
mock_sweep.assert_not_called()

View File

@@ -146,7 +146,7 @@ def stash_pending_tool_output(tool_name: str, output: Any) -> None:
event.set()
async def wait_for_stash(timeout: float = 2.0) -> bool:
async def wait_for_stash(timeout: float = 0.5) -> bool:
"""Wait for a PostToolUse hook to stash tool output.
The SDK fires PostToolUse hooks asynchronously via ``start_soon()`` —
@@ -155,12 +155,12 @@ async def wait_for_stash(timeout: float = 2.0) -> bool:
by waiting on the ``_stash_event``, which is signaled by
:func:`stash_pending_tool_output`.
Returns ``True`` if a stash signal was received, ``False`` on timeout.
After the event fires, callers should ``await asyncio.sleep(0)`` to
give any remaining concurrent hooks a chance to complete.
The 2.0 s default was chosen based on production metrics: the original
0.5 s caused frequent timeouts under load (parallel tool calls, large
outputs). 2.0 s gives a comfortable margin while still failing fast
when the hook genuinely will not fire.
Returns ``True`` if a stash signal was received, ``False`` on timeout.
The timeout is a safety net — normally the stash happens within
microseconds of yielding to the event loop.
"""
event = _stash_event.get(None)
if event is None:
@@ -285,7 +285,7 @@ async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
resolved = os.path.realpath(os.path.expanduser(file_path))
try:
with open(resolved, encoding="utf-8", errors="replace") as f:
with open(resolved) as f:
selected = list(itertools.islice(f, offset, offset + limit))
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
# don't delete here — the SDK may read in multiple chunks.

View File

@@ -151,110 +151,44 @@ def _projects_base() -> str:
return os.path.realpath(os.path.join(config_dir, "projects"))
_STALE_PROJECT_DIR_SECONDS = 12 * 3600 # 12 hours — matches max session lifetime
_MAX_PROJECT_DIRS_TO_SWEEP = 50 # limit per sweep to avoid long pauses
def _cli_project_dir(sdk_cwd: str) -> str | None:
"""Return the CLI's project directory for a given working directory.
def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
"""Remove CLI project directories older than ``_STALE_PROJECT_DIR_SECONDS``.
Each CoPilot SDK turn creates a unique ``~/.claude/projects/<encoded-cwd>/``
directory. These are intentionally kept across turns so the model can read
tool-result files via ``--resume``. However, after a session ends they
become stale. This function sweeps old ones to prevent unbounded disk
growth.
When *encoded_cwd* is provided the sweep is scoped to that single
directory, making the operation safe in multi-tenant environments where
multiple copilot sessions share the same host. Without it the function
falls back to sweeping all directories matching the copilot naming pattern
(``-tmp-copilot-``), which is only safe for single-tenant deployments.
Returns the number of directories removed.
Returns ``None`` if the path would escape the projects base.
"""
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
projects_base = _projects_base()
if not os.path.isdir(projects_base):
return 0
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
now = time.time()
removed = 0
# Scoped mode: only clean up the one directory for the current session.
if encoded_cwd:
target = Path(projects_base) / encoded_cwd
if not target.is_dir():
return 0
# Guard: only sweep copilot-generated dirs.
if "-tmp-copilot-" not in target.name:
logger.warning(
"[Transcript] Refusing to sweep non-copilot dir: %s", target.name
)
return 0
try:
# st_mtime is used as a proxy for session activity. Claude CLI writes
# its JSONL transcript into this directory during each turn, so mtime
# advances on every turn. A directory whose mtime is older than
# _STALE_PROJECT_DIR_SECONDS has not had an active turn in that window
# and is safe to remove (the session cannot --resume after cleanup).
age = now - target.stat().st_mtime
except OSError:
return 0
if age < _STALE_PROJECT_DIR_SECONDS:
return 0
try:
shutil.rmtree(target, ignore_errors=True)
removed = 1
except OSError:
pass
if removed:
logger.info(
"[Transcript] Swept stale CLI project dir %s (age %ds > %ds)",
target.name,
int(age),
_STALE_PROJECT_DIR_SECONDS,
)
return removed
# Unscoped fallback: sweep all copilot dirs across the projects base.
# Only safe for single-tenant deployments; callers should prefer the
# scoped variant by passing encoded_cwd.
try:
entries = Path(projects_base).iterdir()
except OSError as e:
logger.warning("[Transcript] Failed to list projects dir: %s", e)
return 0
for entry in entries:
if removed >= _MAX_PROJECT_DIRS_TO_SWEEP:
break
# Only sweep copilot-generated dirs (pattern: -tmp-copilot- or
# -private-tmp-copilot-).
if "-tmp-copilot-" not in entry.name:
continue
if not entry.is_dir():
continue
try:
# See the scoped-mode comment above: st_mtime advances on every turn,
# so a stale mtime reliably indicates an inactive session.
age = now - entry.stat().st_mtime
except OSError:
continue
if age < _STALE_PROJECT_DIR_SECONDS:
continue
try:
shutil.rmtree(entry, ignore_errors=True)
removed += 1
except OSError:
pass
if removed:
logger.info(
"[Transcript] Swept %d stale CLI project dirs (older than %ds)",
removed,
_STALE_PROJECT_DIR_SECONDS,
if not project_dir.startswith(projects_base + os.sep):
logger.warning(
"[Transcript] Project dir escaped projects base: %s", project_dir
)
return removed
return None
return project_dir
def _safe_glob_jsonl(project_dir: str) -> list[Path]:
"""Glob ``*.jsonl`` files, filtering out symlinks that escape the directory."""
try:
resolved_base = Path(project_dir).resolve()
except OSError as e:
logger.warning("[Transcript] Failed to resolve project dir: %s", e)
return []
result: list[Path] = []
for candidate in Path(project_dir).glob("*.jsonl"):
try:
resolved = candidate.resolve()
if resolved.is_relative_to(resolved_base):
result.append(resolved)
except (OSError, RuntimeError) as e:
logger.debug(
"[Transcript] Skipping invalid CLI session candidate %s: %s",
candidate,
e,
)
return result
def read_compacted_entries(transcript_path: str) -> list[dict] | None:
@@ -321,6 +255,63 @@ def read_compacted_entries(transcript_path: str) -> list[dict] | None:
return entries
def read_cli_session_file(sdk_cwd: str) -> str | None:
"""Read the CLI's own session file, which reflects any compaction.
The CLI writes its session transcript to
``~/.claude/projects/<encoded_cwd>/<session_id>.jsonl``.
Since each SDK turn uses a unique ``sdk_cwd``, there should be
exactly one ``.jsonl`` file in that directory.
Returns the file content, or ``None`` if not found.
"""
project_dir = _cli_project_dir(sdk_cwd)
if not project_dir or not os.path.isdir(project_dir):
return None
jsonl_files = _safe_glob_jsonl(project_dir)
if not jsonl_files:
logger.debug("[Transcript] No CLI session file found in %s", project_dir)
return None
# Pick the most recently modified file (should be only one per turn).
try:
session_file = max(jsonl_files, key=lambda p: p.stat().st_mtime)
except OSError as e:
logger.warning("[Transcript] Failed to inspect CLI session files: %s", e)
return None
try:
content = session_file.read_text()
logger.info(
"[Transcript] Read CLI session file: %s (%d bytes)",
session_file,
len(content),
)
return content
except OSError as e:
logger.warning("[Transcript] Failed to read CLI session file: %s", e)
return None
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
"""Remove the CLI's project directory for a specific working directory.
The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
safe to remove entirely after the transcript has been uploaded.
"""
project_dir = _cli_project_dir(sdk_cwd)
if not project_dir:
return
if os.path.isdir(project_dir):
shutil.rmtree(project_dir, ignore_errors=True)
logger.debug("[Transcript] Cleaned up CLI project dir: %s", project_dir)
else:
logger.debug("[Transcript] Project dir not found: %s", project_dir)
def write_transcript_to_tempfile(
transcript_content: str,
session_id: str,

View File

@@ -9,7 +9,9 @@ from backend.util import json
from .transcript import (
STRIPPABLE_TYPES,
_cli_project_dir,
delete_transcript,
read_cli_session_file,
read_compacted_entries,
strip_progress_entries,
validate_transcript,
@@ -290,6 +292,85 @@ class TestStripProgressEntries:
assert asst_entry["parentUuid"] == "u1" # reparented
# --- read_cli_session_file ---
class TestReadCliSessionFile:
def test_no_matching_files_returns_none(self, tmp_path, monkeypatch):
"""read_cli_session_file returns None when no .jsonl files exist."""
# Create a project dir with no jsonl files
project_dir = tmp_path / "projects" / "encoded-cwd"
project_dir.mkdir(parents=True)
monkeypatch.setattr(
"backend.copilot.sdk.transcript._cli_project_dir",
lambda sdk_cwd: str(project_dir),
)
assert read_cli_session_file("/fake/cwd") is None
def test_one_jsonl_file_returns_content(self, tmp_path, monkeypatch):
"""read_cli_session_file returns the content of a single .jsonl file."""
project_dir = tmp_path / "projects" / "encoded-cwd"
project_dir.mkdir(parents=True)
jsonl_file = project_dir / "session.jsonl"
jsonl_file.write_text("line1\nline2\n")
monkeypatch.setattr(
"backend.copilot.sdk.transcript._cli_project_dir",
lambda sdk_cwd: str(project_dir),
)
result = read_cli_session_file("/fake/cwd")
assert result == "line1\nline2\n"
def test_symlink_escaping_project_dir_is_skipped(self, tmp_path, monkeypatch):
"""read_cli_session_file skips symlinks that escape the project dir."""
project_dir = tmp_path / "projects" / "encoded-cwd"
project_dir.mkdir(parents=True)
# Create a file outside the project dir
outside = tmp_path / "outside"
outside.mkdir()
outside_file = outside / "evil.jsonl"
outside_file.write_text("should not be read\n")
# Symlink from inside project_dir to outside file
symlink = project_dir / "evil.jsonl"
symlink.symlink_to(outside_file)
monkeypatch.setattr(
"backend.copilot.sdk.transcript._cli_project_dir",
lambda sdk_cwd: str(project_dir),
)
# The symlink target resolves outside project_dir, so it should be skipped
result = read_cli_session_file("/fake/cwd")
assert result is None
# --- _cli_project_dir ---
class TestCliProjectDir:
def test_returns_none_for_path_traversal(self, tmp_path, monkeypatch):
"""_cli_project_dir returns None when the project dir symlink escapes projects base."""
config_dir = tmp_path / "config"
config_dir.mkdir()
projects_dir = config_dir / "projects"
projects_dir.mkdir()
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
# Create a symlink inside projects/ that points outside of it.
# _cli_project_dir encodes the cwd as all-alnum-hyphens, so use a
# cwd whose encoded form matches the symlink name we create.
evil_target = tmp_path / "escaped"
evil_target.mkdir()
# The encoded form of "/evil/cwd" is "-evil-cwd"
symlink_path = projects_dir / "-evil-cwd"
symlink_path.symlink_to(evil_target)
result = _cli_project_dir("/evil/cwd")
assert result is None
# --- delete_transcript ---
@@ -816,209 +897,3 @@ class TestCompactionFlowIntegration:
output2 = builder2.to_jsonl()
lines2 = [json.loads(line) for line in output2.strip().split("\n")]
assert lines2[-1]["parentUuid"] == "a2"
# ---------------------------------------------------------------------------
# cleanup_stale_project_dirs
# ---------------------------------------------------------------------------
class TestCleanupStaleProjectDirs:
"""Tests for cleanup_stale_project_dirs (disk leak prevention)."""
def test_removes_old_copilot_dirs(self, tmp_path, monkeypatch):
"""Directories matching copilot pattern older than threshold are removed."""
from backend.copilot.sdk.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: str(projects_dir),
)
# Create a stale dir
stale = projects_dir / "-tmp-copilot-old-session"
stale.mkdir()
# Set mtime to past the threshold
import time
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
os.utime(stale, (old_time, old_time))
# Create a fresh dir
fresh = projects_dir / "-tmp-copilot-new-session"
fresh.mkdir()
removed = cleanup_stale_project_dirs()
assert removed == 1
assert not stale.exists()
assert fresh.exists()
def test_ignores_non_copilot_dirs(self, tmp_path, monkeypatch):
"""Directories not matching copilot pattern are left alone."""
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: str(projects_dir),
)
# Non-copilot dir that's old
import time
other = projects_dir / "some-other-project"
other.mkdir()
old_time = time.time() - 999999
os.utime(other, (old_time, old_time))
removed = cleanup_stale_project_dirs()
assert removed == 0
assert other.exists()
def test_ttl_boundary_not_removed(self, tmp_path, monkeypatch):
"""A directory exactly at the TTL boundary should NOT be removed."""
from backend.copilot.sdk.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: str(projects_dir),
)
import time
# Dir that's exactly at the TTL (age == threshold, not >) — should survive
boundary = projects_dir / "-tmp-copilot-boundary"
boundary.mkdir()
boundary_time = time.time() - _STALE_PROJECT_DIR_SECONDS + 1
os.utime(boundary, (boundary_time, boundary_time))
removed = cleanup_stale_project_dirs()
assert removed == 0
assert boundary.exists()
def test_skips_non_directory_entries(self, tmp_path, monkeypatch):
"""Regular files matching the copilot pattern are not removed."""
from backend.copilot.sdk.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: str(projects_dir),
)
import time
# Create a regular FILE (not a dir) with the copilot pattern name
stale_file = projects_dir / "-tmp-copilot-stale-file"
stale_file.write_text("not a dir")
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
os.utime(stale_file, (old_time, old_time))
removed = cleanup_stale_project_dirs()
assert removed == 0
assert stale_file.exists()
def test_missing_base_dir_returns_zero(self, tmp_path, monkeypatch):
"""If the projects base directory doesn't exist, return 0 gracefully."""
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
nonexistent = str(tmp_path / "does-not-exist" / "projects")
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: nonexistent,
)
removed = cleanup_stale_project_dirs()
assert removed == 0
def test_scoped_removes_only_target_dir(self, tmp_path, monkeypatch):
"""When encoded_cwd is supplied only that directory is swept."""
import time
from backend.copilot.sdk.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: str(projects_dir),
)
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
# Two stale copilot dirs
target = projects_dir / "-tmp-copilot-session-abc"
target.mkdir()
os.utime(target, (old_time, old_time))
other = projects_dir / "-tmp-copilot-session-xyz"
other.mkdir()
os.utime(other, (old_time, old_time))
# Only the target dir should be removed
removed = cleanup_stale_project_dirs(encoded_cwd="-tmp-copilot-session-abc")
assert removed == 1
assert not target.exists()
assert other.exists() # untouched — not the current session
def test_scoped_fresh_dir_not_removed(self, tmp_path, monkeypatch):
"""Scoped sweep leaves a fresh directory alone."""
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: str(projects_dir),
)
fresh = projects_dir / "-tmp-copilot-session-new"
fresh.mkdir()
# mtime is now — well within TTL
removed = cleanup_stale_project_dirs(encoded_cwd="-tmp-copilot-session-new")
assert removed == 0
assert fresh.exists()
def test_scoped_non_copilot_dir_not_removed(self, tmp_path, monkeypatch):
"""Scoped sweep refuses to remove a non-copilot directory."""
import time
from backend.copilot.sdk.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: str(projects_dir),
)
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
non_copilot = projects_dir / "some-other-project"
non_copilot.mkdir()
os.utime(non_copilot, (old_time, old_time))
removed = cleanup_stale_project_dirs(encoded_cwd="some-other-project")
assert removed == 0
assert non_copilot.exists()

View File

@@ -12,6 +12,7 @@ from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreensho
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .connect_integration import ConnectIntegrationTool
from .continue_run_block import ContinueRunBlockTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
@@ -84,6 +85,7 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"browser_screenshot": BrowserScreenshotTool(),
# Sandboxed code execution (bubblewrap)
"bash_exec": BashExecTool(),
"connect_integration": ConnectIntegrationTool(),
# Persistent workspace tools (cloud storage, survives across sessions)
# Feature request tools
"search_feature_requests": SearchFeatureRequestsTool(),

View File

@@ -22,6 +22,7 @@ from e2b import AsyncSandbox
from e2b.exceptions import TimeoutException
from backend.copilot.context import E2B_WORKDIR, get_current_sandbox
from backend.copilot.integration_creds import get_integration_env_vars
from backend.copilot.model import ChatSession
from .base import BaseTool
@@ -96,7 +97,9 @@ class BashExecTool(BaseTool):
sandbox = get_current_sandbox()
if sandbox is not None:
return await self._execute_on_e2b(sandbox, command, timeout, session_id)
return await self._execute_on_e2b(
sandbox, command, timeout, session_id, user_id
)
# Bubblewrap fallback: local isolated execution.
if not has_full_sandbox():
@@ -133,14 +136,27 @@ class BashExecTool(BaseTool):
command: str,
timeout: int,
session_id: str | None,
user_id: str | None = None,
) -> ToolResponseBase:
"""Execute *command* on the E2B sandbox via commands.run()."""
"""Execute *command* on the E2B sandbox via commands.run().
Integration tokens (e.g. GH_TOKEN) are injected into the sandbox env
for any user with connected accounts. E2B has full internet access, so
CLI tools like ``gh`` work without manual authentication.
"""
envs: dict[str, str] = {
"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin",
}
if user_id is not None:
integration_env = await get_integration_env_vars(user_id)
envs.update(integration_env)
try:
result = await sandbox.commands.run(
f"bash -c {shlex.quote(command)}",
cwd=E2B_WORKDIR,
timeout=timeout,
envs={"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"},
envs=envs,
)
return BashExecResponse(
message=f"Command executed on E2B (exit {result.exit_code})",

View File

@@ -0,0 +1,78 @@
"""Tests for BashExecTool — E2B path with token injection."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from ._test_data import make_session
from .bash_exec import BashExecTool
from .models import BashExecResponse
_USER = "user-bash-exec-test"
def _make_tool() -> BashExecTool:
return BashExecTool()
def _make_sandbox(exit_code: int = 0, stdout: str = "", stderr: str = "") -> MagicMock:
result = MagicMock()
result.exit_code = exit_code
result.stdout = stdout
result.stderr = stderr
sandbox = MagicMock()
sandbox.commands.run = AsyncMock(return_value=result)
return sandbox
class TestBashExecE2BTokenInjection:
@pytest.mark.asyncio(loop_scope="session")
async def test_token_injected_when_user_id_set(self):
"""When user_id is provided, integration env vars are merged into sandbox envs."""
tool = _make_tool()
session = make_session(user_id=_USER)
sandbox = _make_sandbox(stdout="ok")
env_vars = {"GH_TOKEN": "gh-secret", "GITHUB_TOKEN": "gh-secret"}
with patch(
"backend.copilot.tools.bash_exec.get_integration_env_vars",
new=AsyncMock(return_value=env_vars),
) as mock_get_env:
result = await tool._execute_on_e2b(
sandbox=sandbox,
command="echo hi",
timeout=10,
session_id=session.session_id,
user_id=_USER,
)
mock_get_env.assert_awaited_once_with(_USER)
call_kwargs = sandbox.commands.run.call_args[1]
assert call_kwargs["envs"]["GH_TOKEN"] == "gh-secret"
assert call_kwargs["envs"]["GITHUB_TOKEN"] == "gh-secret"
assert isinstance(result, BashExecResponse)
@pytest.mark.asyncio(loop_scope="session")
async def test_no_token_injection_when_user_id_is_none(self):
"""When user_id is None, get_integration_env_vars must NOT be called."""
tool = _make_tool()
session = make_session(user_id=_USER)
sandbox = _make_sandbox(stdout="ok")
with patch(
"backend.copilot.tools.bash_exec.get_integration_env_vars",
new=AsyncMock(return_value={"GH_TOKEN": "should-not-appear"}),
) as mock_get_env:
result = await tool._execute_on_e2b(
sandbox=sandbox,
command="echo hi",
timeout=10,
session_id=session.session_id,
user_id=None,
)
mock_get_env.assert_not_called()
call_kwargs = sandbox.commands.run.call_args[1]
assert "GH_TOKEN" not in call_kwargs["envs"]
assert isinstance(result, BashExecResponse)

View File

@@ -0,0 +1,215 @@
"""Tool for prompting the user to connect a required integration.
When the copilot encounters an authentication failure (e.g. `gh` CLI returns
"authentication required"), it calls this tool to surface the credentials
setup card in the chat — the same UI that appears when a GitHub block runs
without configured credentials.
"""
import functools
from typing import Any, TypedDict
from backend.copilot.model import ChatSession
from backend.copilot.tools.models import (
ErrorResponse,
ResponseType,
SetupInfo,
SetupRequirementsResponse,
ToolResponseBase,
UserReadiness,
)
from .base import BaseTool
class _ProviderInfo(TypedDict):
name: str
types: list[str]
# Default OAuth scopes requested when the agent doesn't specify any.
scopes: list[str]
class _CredentialEntry(TypedDict):
"""Shape of each entry inside SetupRequirementsResponse.user_readiness.missing_credentials."""
id: str
title: str
provider: str
provider_name: str
type: str
types: list[str]
scopes: list[str]
@functools.lru_cache(maxsize=1)
def _is_github_oauth_configured() -> bool:
"""Return True if GitHub OAuth env vars are set.
Evaluated lazily (not at import time) to avoid triggering Secrets() during
module import, which can fail in environments where secrets are not loaded.
"""
from backend.blocks.github._auth import GITHUB_OAUTH_IS_CONFIGURED
return GITHUB_OAUTH_IS_CONFIGURED
# Registry of known providers: name + supported credential types for the UI.
# When adding a new provider, also add its env var names to
# backend.copilot.integration_creds.PROVIDER_ENV_VARS.
def _get_provider_info() -> dict[str, _ProviderInfo]:
"""Build the provider registry, evaluating OAuth config lazily."""
return {
"github": {
"name": "GitHub",
"types": (
["api_key", "oauth2"] if _is_github_oauth_configured() else ["api_key"]
),
# Default: repo scope covers clone/push/pull for public and private repos.
# Agent can request additional scopes (e.g. "read:org") via the scopes param.
"scopes": ["repo"],
},
}
class ConnectIntegrationTool(BaseTool):
"""Surface the credentials setup UI when an integration is not connected."""
@property
def name(self) -> str:
return "connect_integration"
@property
def description(self) -> str:
return (
"Prompt the user to connect a required integration (e.g. GitHub). "
"Call this when an external CLI or API call fails because the user "
"has not connected the relevant account. "
"The tool surfaces a credentials setup card in the chat so the user "
"can authenticate without leaving the page. "
"After the user connects the account, retry the operation. "
"In E2B/cloud sandbox mode the token (GH_TOKEN/GITHUB_TOKEN) is "
"automatically injected per-command in bash_exec — no manual export needed. "
"In local bubblewrap mode network is isolated so GitHub CLI commands "
"will still fail after connecting; inform the user of this limitation."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"provider": {
"type": "string",
"description": (
"Integration provider slug, e.g. 'github'. "
"Must be one of the supported providers."
),
"enum": list(_get_provider_info().keys()),
},
"reason": {
"type": "string",
"description": (
"Brief explanation of why the integration is needed, "
"shown to the user in the setup card."
),
"maxLength": 500,
},
"scopes": {
"type": "array",
"items": {"type": "string"},
"description": (
"OAuth scopes to request. Omit to use the provider default. "
"Add extra scopes when you need more access — e.g. for GitHub: "
"'repo' (clone/push/pull), 'read:org' (org membership), "
"'workflow' (GitHub Actions). "
"Requesting only the scopes you actually need is best practice."
),
},
},
"required": ["provider"],
}
@property
def requires_auth(self) -> bool:
# Require auth so only authenticated users can trigger the setup card.
# The card itself is user-agnostic (no per-user data needed), so
# user_id is intentionally unused in _execute.
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
) -> ToolResponseBase:
del user_id # setup card is user-agnostic; auth is enforced via requires_auth
session_id = session.session_id if session else None
provider: str = (kwargs.get("provider") or "").strip().lower()
reason: str = (kwargs.get("reason") or "").strip()[
:500
] # cap LLM-controlled text
extra_scopes: list[str] = [
str(s).strip() for s in (kwargs.get("scopes") or []) if str(s).strip()
]
provider_info = _get_provider_info()
info = provider_info.get(provider)
if not info:
supported = ", ".join(f"'{p}'" for p in provider_info)
return ErrorResponse(
message=(
f"Unknown provider '{provider}'. "
f"Supported providers: {supported}."
),
error="unknown_provider",
session_id=session_id,
)
provider_name: str = info["name"]
supported_types: list[str] = info["types"]
# Merge agent-requested scopes with provider defaults (deduplicated, order preserved).
default_scopes: list[str] = info["scopes"]
seen: set[str] = set()
scopes: list[str] = []
for s in default_scopes + extra_scopes:
if s not in seen:
seen.add(s)
scopes.append(s)
field_key = f"{provider}_credentials"
message_parts = [
f"To continue, please connect your {provider_name} account.",
]
if reason:
message_parts.append(reason)
credential_entry: _CredentialEntry = {
"id": field_key,
"title": f"{provider_name} Credentials",
"provider": provider,
"provider_name": provider_name,
"type": supported_types[0],
"types": supported_types,
"scopes": scopes,
}
missing_credentials: dict[str, _CredentialEntry] = {field_key: credential_entry}
return SetupRequirementsResponse(
type=ResponseType.SETUP_REQUIREMENTS,
message=" ".join(message_parts),
session_id=session_id,
setup_info=SetupInfo(
agent_id=f"connect_{provider}",
agent_name=provider_name,
user_readiness=UserReadiness(
has_all_credentials=False,
missing_credentials=missing_credentials,
ready_to_run=False,
),
requirements={
"credentials": [missing_credentials[field_key]],
"inputs": [],
"execution_modes": [],
},
),
)

View File

@@ -0,0 +1,135 @@
"""Tests for ConnectIntegrationTool."""
import pytest
from ._test_data import make_session
from .connect_integration import ConnectIntegrationTool
from .models import ErrorResponse, SetupRequirementsResponse
_TEST_USER_ID = "test-user-connect-integration"
class TestConnectIntegrationTool:
def _make_tool(self) -> ConnectIntegrationTool:
return ConnectIntegrationTool()
@pytest.mark.asyncio(loop_scope="session")
async def test_unknown_provider_returns_error(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="nonexistent"
)
assert isinstance(result, ErrorResponse)
assert result.error == "unknown_provider"
assert "nonexistent" in result.message
assert "github" in result.message
@pytest.mark.asyncio(loop_scope="session")
async def test_empty_provider_returns_error(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider=""
)
assert isinstance(result, ErrorResponse)
assert result.error == "unknown_provider"
@pytest.mark.asyncio(loop_scope="session")
async def test_github_provider_returns_setup_response(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="github"
)
assert isinstance(result, SetupRequirementsResponse)
assert result.setup_info.agent_name == "GitHub"
assert result.setup_info.agent_id == "connect_github"
@pytest.mark.asyncio(loop_scope="session")
async def test_github_has_missing_credentials_in_readiness(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="github"
)
assert isinstance(result, SetupRequirementsResponse)
readiness = result.setup_info.user_readiness
assert readiness.has_all_credentials is False
assert readiness.ready_to_run is False
assert "github_credentials" in readiness.missing_credentials
@pytest.mark.asyncio(loop_scope="session")
async def test_github_requirements_include_credential_entry(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="github"
)
assert isinstance(result, SetupRequirementsResponse)
creds = result.setup_info.requirements["credentials"]
assert len(creds) == 1
assert creds[0]["provider"] == "github"
assert creds[0]["id"] == "github_credentials"
@pytest.mark.asyncio(loop_scope="session")
async def test_reason_appears_in_message(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
reason = "Needed to create a pull request."
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="github", reason=reason
)
assert isinstance(result, SetupRequirementsResponse)
assert reason in result.message
@pytest.mark.asyncio(loop_scope="session")
async def test_session_id_propagated(self):
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="github"
)
assert isinstance(result, SetupRequirementsResponse)
assert result.session_id == session.session_id
@pytest.mark.asyncio(loop_scope="session")
async def test_provider_case_insensitive(self):
"""Provider slug is normalised to lowercase before lookup."""
tool = self._make_tool()
session = make_session(user_id=_TEST_USER_ID)
result = await tool._execute(
user_id=_TEST_USER_ID, session=session, provider="GitHub"
)
assert isinstance(result, SetupRequirementsResponse)
def test_tool_name(self):
assert ConnectIntegrationTool().name == "connect_integration"
def test_requires_auth(self):
assert ConnectIntegrationTool().requires_auth is True
@pytest.mark.asyncio(loop_scope="session")
async def test_unauthenticated_user_gets_need_login_response(self):
"""execute() with user_id=None must return NeedLoginResponse, not the setup card.
This verifies that the requires_auth guard in BaseTool.execute() fires
before _execute() is called, so unauthenticated callers cannot probe
which integrations are configured.
"""
import json
tool = self._make_tool()
# Session still needs a user_id string; the None is passed to execute()
# to simulate an unauthenticated call.
session = make_session(user_id=_TEST_USER_ID)
result = await tool.execute(
user_id=None,
session=session,
tool_call_id="test-call-id",
provider="github",
)
raw = result.output
output = json.loads(raw) if isinstance(raw, str) else raw
assert output.get("type") == "need_login"
assert result.success is False

View File

@@ -2,7 +2,6 @@
import base64
import logging
import mimetypes
import os
from typing import Any, Optional
@@ -11,9 +10,7 @@ from pydantic import BaseModel
from backend.copilot.context import (
E2B_WORKDIR,
get_current_sandbox,
get_sdk_cwd,
get_workspace_manager,
is_allowed_local_path,
resolve_sandbox_path,
)
from backend.copilot.model import ChatSession
@@ -27,10 +24,6 @@ from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
# Sentinel file_id used when a tool-result file is read directly from the local
# host filesystem (rather than from workspace storage).
_LOCAL_TOOL_RESULT_FILE_ID = "local"
async def _resolve_write_content(
content_text: str | None,
@@ -282,93 +275,6 @@ class WorkspaceFileContentResponse(ToolResponseBase):
content_base64: str
_MAX_LOCAL_TOOL_RESULT_BYTES = 10 * 1024 * 1024 # 10 MB
def _read_local_tool_result(
path: str,
char_offset: int,
char_length: Optional[int],
session_id: str,
sdk_cwd: str | None = None,
) -> ToolResponseBase:
"""Read an SDK tool-result file from local disk.
This is a fallback for when the model mistakenly calls
``read_workspace_file`` with an SDK tool-result path that only exists on
the host filesystem, not in cloud workspace storage.
Defence-in-depth: validates *path* via :func:`is_allowed_local_path`
regardless of what the caller has already checked.
"""
# TOCTOU: path validated then opened separately. Acceptable because
# the tool-results directory is server-controlled, not user-writable.
expanded = os.path.realpath(os.path.expanduser(path))
# Defence-in-depth: re-check with resolved path (caller checked raw path).
if not is_allowed_local_path(expanded, sdk_cwd or get_sdk_cwd()):
return ErrorResponse(
message=f"Path not allowed: {os.path.basename(path)}", session_id=session_id
)
try:
# The 10 MB cap (_MAX_LOCAL_TOOL_RESULT_BYTES) bounds memory usage.
# Pre-read size check prevents loading files far above the cap;
# the remaining TOCTOU gap is acceptable for server-controlled paths.
file_size = os.path.getsize(expanded)
if file_size > _MAX_LOCAL_TOOL_RESULT_BYTES:
return ErrorResponse(
message=(f"File too large: {os.path.basename(path)}"),
session_id=session_id,
)
# Detect binary files: try strict UTF-8 first, fall back to
# base64-encoding the raw bytes for binary content.
with open(expanded, "rb") as fh:
raw = fh.read()
try:
text_content = raw.decode("utf-8")
except UnicodeDecodeError:
# Binary file — return raw base64, ignore char_offset/char_length
return WorkspaceFileContentResponse(
file_id=_LOCAL_TOOL_RESULT_FILE_ID,
name=os.path.basename(path),
path=path,
mime_type=mimetypes.guess_type(path)[0] or "application/octet-stream",
content_base64=base64.b64encode(raw).decode("ascii"),
message=(
f"Read {file_size:,} bytes (binary) from local tool-result "
f"{os.path.basename(path)}"
),
session_id=session_id,
)
end = (
char_offset + char_length if char_length is not None else len(text_content)
)
slice_text = text_content[char_offset:end]
except FileNotFoundError:
return ErrorResponse(
message=f"File not found: {os.path.basename(path)}", session_id=session_id
)
except Exception as exc:
return ErrorResponse(
message=f"Error reading file: {type(exc).__name__}", session_id=session_id
)
return WorkspaceFileContentResponse(
file_id=_LOCAL_TOOL_RESULT_FILE_ID,
name=os.path.basename(path),
path=path,
mime_type=mimetypes.guess_type(path)[0] or "text/plain",
content_base64=base64.b64encode(slice_text.encode("utf-8")).decode("ascii"),
message=(
f"Read chars {char_offset}\u2013{char_offset + len(slice_text)} "
f"of {len(text_content):,} chars from local tool-result "
f"{os.path.basename(path)}"
),
session_id=session_id,
)
class WorkspaceFileMetadataResponse(ToolResponseBase):
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
@@ -627,14 +533,6 @@ class ReadWorkspaceFileTool(BaseTool):
manager = await get_workspace_manager(user_id, session_id)
resolved = await _resolve_file(manager, file_id, path, session_id)
if isinstance(resolved, ErrorResponse):
# Fallback: if the path is an SDK tool-result on local disk,
# read it directly instead of failing. The model sometimes
# calls read_workspace_file for these paths by mistake.
sdk_cwd = get_sdk_cwd()
if path and is_allowed_local_path(path, sdk_cwd):
return _read_local_tool_result(
path, char_offset, char_length, session_id, sdk_cwd=sdk_cwd
)
return resolved
target_file_id, file_info = resolved

View File

@@ -2,25 +2,18 @@
import base64
import os
import shutil
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.context import SDK_PROJECTS_DIR, _current_project_dir
from backend.copilot.tools._test_data import make_session, setup_test_data
from backend.copilot.tools.models import ErrorResponse
from backend.copilot.tools.workspace_files import (
_MAX_LOCAL_TOOL_RESULT_BYTES,
DeleteWorkspaceFileTool,
ListWorkspaceFilesTool,
ReadWorkspaceFileTool,
WorkspaceDeleteResponse,
WorkspaceFileContentResponse,
WorkspaceFileListResponse,
WorkspaceWriteResponse,
WriteWorkspaceFileTool,
_read_local_tool_result,
_resolve_write_content,
_validate_ephemeral_path,
)
@@ -332,294 +325,3 @@ async def test_write_workspace_file_source_path(setup_test_data):
await delete_tool._execute(
user_id=user.id, session=session, file_id=write_resp.file_id
)
# ---------------------------------------------------------------------------
# _read_local_tool_result — local disk fallback for SDK tool-result files
# ---------------------------------------------------------------------------
_CONV_UUID = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
class TestReadLocalToolResult:
"""Tests for _read_local_tool_result (local disk fallback)."""
def _make_tool_result(self, encoded: str, filename: str, content: bytes) -> str:
"""Create a tool-results file and return its path."""
tool_dir = os.path.join(SDK_PROJECTS_DIR, encoded, _CONV_UUID, "tool-results")
os.makedirs(tool_dir, exist_ok=True)
filepath = os.path.join(tool_dir, filename)
with open(filepath, "wb") as f:
f.write(content)
return filepath
def _cleanup(self, encoded: str) -> None:
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
def test_read_text_file(self):
"""Read a UTF-8 text tool-result file."""
encoded = "-tmp-copilot-local-read-text"
path = self._make_tool_result(encoded, "output.txt", b"hello world")
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 0, None, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
decoded = base64.b64decode(result.content_base64).decode("utf-8")
assert decoded == "hello world"
assert "text/plain" in result.mime_type
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_read_text_with_offset(self):
"""Read a slice of a text file using char_offset and char_length."""
encoded = "-tmp-copilot-local-read-offset"
path = self._make_tool_result(encoded, "data.txt", b"ABCDEFGHIJ")
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 3, 4, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
decoded = base64.b64decode(result.content_base64).decode("utf-8")
assert decoded == "DEFG"
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_read_binary_file(self):
"""Binary files are returned as raw base64."""
encoded = "-tmp-copilot-local-read-binary"
binary_data = bytes(range(256))
path = self._make_tool_result(encoded, "image.png", binary_data)
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 0, None, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
decoded = base64.b64decode(result.content_base64)
assert decoded == binary_data
assert "binary" in result.message
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_disallowed_path_rejected(self):
"""Paths not under allowed directories are rejected."""
result = _read_local_tool_result("/etc/passwd", 0, None, "s1")
assert isinstance(result, ErrorResponse)
assert "not allowed" in result.message.lower()
def test_file_not_found(self):
"""Missing files return an error."""
encoded = "-tmp-copilot-local-read-missing"
tool_dir = os.path.join(SDK_PROJECTS_DIR, encoded, _CONV_UUID, "tool-results")
os.makedirs(tool_dir, exist_ok=True)
path = os.path.join(tool_dir, "nope.txt")
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 0, None, "s1")
assert isinstance(result, ErrorResponse)
assert "not found" in result.message.lower()
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_file_too_large(self, monkeypatch):
"""Files exceeding the size limit are rejected."""
encoded = "-tmp-copilot-local-read-large"
# Create a small file but fake os.path.getsize to return a huge value
path = self._make_tool_result(encoded, "big.txt", b"small")
token = _current_project_dir.set(encoded)
monkeypatch.setattr(
"os.path.getsize", lambda _: _MAX_LOCAL_TOOL_RESULT_BYTES + 1
)
try:
result = _read_local_tool_result(path, 0, None, "s1")
assert isinstance(result, ErrorResponse)
assert "too large" in result.message.lower()
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_offset_beyond_file_length(self):
"""Offset past end-of-file returns empty content."""
encoded = "-tmp-copilot-local-read-past-eof"
path = self._make_tool_result(encoded, "short.txt", b"abc")
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 999, 10, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
decoded = base64.b64decode(result.content_base64).decode("utf-8")
assert decoded == ""
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_zero_length_read(self):
"""Requesting zero characters returns empty content."""
encoded = "-tmp-copilot-local-read-zero-len"
path = self._make_tool_result(encoded, "data.txt", b"ABCDEF")
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 2, 0, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
decoded = base64.b64decode(result.content_base64).decode("utf-8")
assert decoded == ""
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_mime_type_from_json_extension(self):
"""JSON files get application/json MIME type, not hardcoded text/plain."""
encoded = "-tmp-copilot-local-read-json"
path = self._make_tool_result(encoded, "result.json", b'{"key": "value"}')
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 0, None, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
assert result.mime_type == "application/json"
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_mime_type_from_png_extension(self):
"""Binary .png files get image/png MIME type via mimetypes."""
encoded = "-tmp-copilot-local-read-png-mime"
binary_data = bytes(range(256))
path = self._make_tool_result(encoded, "chart.png", binary_data)
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 0, None, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
assert result.mime_type == "image/png"
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_explicit_sdk_cwd_parameter(self):
"""The sdk_cwd parameter overrides get_sdk_cwd() for path validation."""
encoded = "-tmp-copilot-local-read-sdkcwd"
path = self._make_tool_result(encoded, "out.txt", b"content")
token = _current_project_dir.set(encoded)
try:
# Pass sdk_cwd explicitly — should still succeed because the path
# is under SDK_PROJECTS_DIR which is always allowed.
result = _read_local_tool_result(
path, 0, None, "s1", sdk_cwd="/tmp/copilot-test"
)
assert isinstance(result, WorkspaceFileContentResponse)
decoded = base64.b64decode(result.content_base64).decode("utf-8")
assert decoded == "content"
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_offset_with_no_length_reads_to_end(self):
"""When char_length is None, read from offset to end of file."""
encoded = "-tmp-copilot-local-read-offset-noLen"
path = self._make_tool_result(encoded, "data.txt", b"0123456789")
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 5, None, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
decoded = base64.b64decode(result.content_base64).decode("utf-8")
assert decoded == "56789"
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
# ---------------------------------------------------------------------------
# ReadWorkspaceFileTool fallback to _read_local_tool_result
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_read_workspace_file_falls_back_to_local_tool_result(setup_test_data):
"""When _resolve_file returns ErrorResponse for an allowed local path,
ReadWorkspaceFileTool should fall back to _read_local_tool_result."""
user = setup_test_data["user"]
session = make_session(user.id)
# Create a real tool-result file on disk so the fallback can read it.
encoded = "-tmp-copilot-fallback-test"
conv_uuid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
tool_dir = os.path.join(SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-results")
os.makedirs(tool_dir, exist_ok=True)
filepath = os.path.join(tool_dir, "result.txt")
with open(filepath, "w") as f:
f.write("fallback content")
token = _current_project_dir.set(encoded)
try:
# Mock _resolve_file to return an ErrorResponse (simulating "file not
# found in workspace") so the fallback branch is exercised.
mock_resolve = AsyncMock(
return_value=ErrorResponse(
message="File not found at path: result.txt",
session_id=session.session_id,
)
)
with patch("backend.copilot.tools.workspace_files._resolve_file", mock_resolve):
read_tool = ReadWorkspaceFileTool()
result = await read_tool._execute(
user_id=user.id,
session=session,
path=filepath,
)
# Should have fallen back to _read_local_tool_result and succeeded.
assert isinstance(result, WorkspaceFileContentResponse), (
f"Expected fallback to local read, got {type(result).__name__}: "
f"{getattr(result, 'message', '')}"
)
decoded = base64.b64decode(result.content_base64).decode("utf-8")
assert decoded == "fallback content"
mock_resolve.assert_awaited_once()
finally:
_current_project_dir.reset(token)
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
@pytest.mark.asyncio(loop_scope="session")
async def test_read_workspace_file_no_fallback_when_resolve_succeeds(setup_test_data):
"""When _resolve_file succeeds, the local-disk fallback must NOT be invoked."""
user = setup_test_data["user"]
session = make_session(user.id)
fake_file_id = "fake-file-id-001"
fake_content = b"workspace content"
# Build a minimal file_info stub that the tool's happy-path needs.
class _FakeFileInfo:
id = fake_file_id
name = "result.json"
path = "/result.json"
mime_type = "text/plain"
size_bytes = len(fake_content)
mock_resolve = AsyncMock(return_value=(fake_file_id, _FakeFileInfo()))
mock_manager = AsyncMock()
mock_manager.read_file_by_id = AsyncMock(return_value=fake_content)
with (
patch("backend.copilot.tools.workspace_files._resolve_file", mock_resolve),
patch(
"backend.copilot.tools.workspace_files.get_workspace_manager",
AsyncMock(return_value=mock_manager),
),
patch(
"backend.copilot.tools.workspace_files._read_local_tool_result"
) as patched_local,
):
read_tool = ReadWorkspaceFileTool()
result = await read_tool._execute(
user_id=user.id,
session=session,
file_id=fake_file_id,
)
# Fallback must not have been called.
patched_local.assert_not_called()
# Normal workspace path must have produced a content response.
assert isinstance(result, WorkspaceFileContentResponse)
assert base64.b64decode(result.content_base64) == fake_content

View File

@@ -25,6 +25,35 @@ logger = logging.getLogger(__name__)
settings = Settings()
_on_creds_changed: Callable[[str, str], None] | None = None
def register_creds_changed_hook(hook: Callable[[str, str], None]) -> None:
"""Register a callback invoked after any credential is created/updated/deleted.
The callback receives ``(user_id, provider)`` and should be idempotent.
Only one hook can be registered at a time; calling this again replaces the
previous hook. Intended to be called once at application startup by the
copilot module to bust its token cache without creating an import cycle.
"""
global _on_creds_changed
_on_creds_changed = hook
def _bust_copilot_cache(user_id: str, provider: str) -> None:
"""Invoke the registered hook (if any) to bust downstream token caches."""
if _on_creds_changed is not None:
try:
_on_creds_changed(user_id, provider)
except Exception:
logger.warning(
"Credential-change hook failed for user=%s provider=%s",
user_id,
provider,
exc_info=True,
)
class IntegrationCredentialsManager:
"""
Handles the lifecycle of integration credentials.
@@ -69,7 +98,11 @@ class IntegrationCredentialsManager:
return self._locks
async def create(self, user_id: str, credentials: Credentials) -> None:
return await self.store.add_creds(user_id, credentials)
result = await self.store.add_creds(user_id, credentials)
# Bust the copilot token cache so that the next bash_exec picks up the
# new credential immediately instead of waiting for _NULL_CACHE_TTL.
_bust_copilot_cache(user_id, credentials.provider)
return result
async def exists(self, user_id: str, credentials_id: str) -> bool:
return (await self.store.get_creds_by_id(user_id, credentials_id)) is not None
@@ -156,6 +189,8 @@ class IntegrationCredentialsManager:
fresh_credentials = await oauth_handler.refresh_tokens(credentials)
await self.store.update_creds(user_id, fresh_credentials)
# Bust copilot cache so the refreshed token is picked up immediately.
_bust_copilot_cache(user_id, fresh_credentials.provider)
if _lock and (await _lock.locked()) and (await _lock.owned()):
try:
await _lock.release()
@@ -168,10 +203,17 @@ class IntegrationCredentialsManager:
async def update(self, user_id: str, updated: Credentials) -> None:
async with self._locked(user_id, updated.id):
await self.store.update_creds(user_id, updated)
# Bust the copilot token cache so the updated credential is picked up immediately.
_bust_copilot_cache(user_id, updated.provider)
async def delete(self, user_id: str, credentials_id: str) -> None:
async with self._locked(user_id, credentials_id):
# Read inside the lock to avoid TOCTOU — another coroutine could
# delete the same credential between the read and the delete.
creds = await self.store.get_creds_by_id(user_id, credentials_id)
await self.store.delete_creds_by_id(user_id, credentials_id)
if creds:
_bust_copilot_cache(user_id, creds.provider)
# -- Locking utilities -- #

View File

@@ -3,6 +3,7 @@ import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { ExclamationMarkIcon } from "@phosphor-icons/react";
import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
import { useState } from "react";
import { ConnectIntegrationTool } from "../../../tools/ConnectIntegrationTool/ConnectIntegrationTool";
import { CreateAgentTool } from "../../../tools/CreateAgent/CreateAgent";
import { EditAgentTool } from "../../../tools/EditAgent/EditAgent";
import {
@@ -129,6 +130,8 @@ export function MessagePartRenderer({ part, messageID, partIndex }: Props) {
case "tool-search_docs":
case "tool-get_doc_page":
return <SearchDocsTool key={key} part={part as ToolUIPart} />;
case "tool-connect_integration":
return <ConnectIntegrationTool key={key} part={part as ToolUIPart} />;
case "tool-run_block":
case "tool-continue_run_block":
return <RunBlockTool key={key} part={part as ToolUIPart} />;

View File

@@ -0,0 +1,104 @@
"use client";
import type { SetupRequirementsResponse } from "@/app/api/__generated__/models/setupRequirementsResponse";
import type { ToolUIPart } from "ai";
import { useState } from "react";
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
import { ContentMessage } from "../../components/ToolAccordion/AccordionContent";
import { SetupRequirementsCard } from "../RunBlock/components/SetupRequirementsCard/SetupRequirementsCard";
type Props = {
part: ToolUIPart;
};
function parseJson(raw: unknown): unknown {
if (typeof raw === "string") {
try {
return JSON.parse(raw);
} catch {
return null;
}
}
return raw;
}
function parseOutput(raw: unknown): SetupRequirementsResponse | null {
const parsed = parseJson(raw);
if (parsed && typeof parsed === "object" && "setup_info" in parsed) {
return parsed as SetupRequirementsResponse;
}
return null;
}
function parseError(raw: unknown): string | null {
const parsed = parseJson(raw);
if (parsed && typeof parsed === "object" && "message" in parsed) {
return String((parsed as { message: unknown }).message);
}
return null;
}
export function ConnectIntegrationTool({ part }: Props) {
// Persist dismissed state here so SetupRequirementsCard remounts don't re-enable Proceed.
const [isDismissed, setIsDismissed] = useState(false);
const isStreaming =
part.state === "input-streaming" || part.state === "input-available";
const isError = part.state === "output-error";
const output =
part.state === "output-available"
? parseOutput((part as { output?: unknown }).output)
: null;
const errorMessage = isError
? (parseError((part as { output?: unknown }).output) ??
"Failed to connect integration")
: null;
const rawProvider =
(part as { input?: { provider?: string } }).input?.provider ?? "";
const providerName =
output?.setup_info?.agent_name ??
// Sanitize LLM-controlled provider slug: trim and cap at 64 chars to
// prevent runaway text in the DOM.
(rawProvider ? rawProvider.trim().slice(0, 64) : "integration");
const label = isStreaming
? `Connecting ${providerName}`
: isError
? `Failed to connect ${providerName}`
: output
? `Connect ${output.setup_info?.agent_name ?? providerName}`
: `Connect ${providerName}`;
return (
<div className="py-2">
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<MorphingTextAnimation
text={label}
className={isError ? "text-red-500" : undefined}
/>
</div>
{isError && errorMessage && (
<p className="mt-1 text-sm text-red-500">{errorMessage}</p>
)}
{output && (
<div className="mt-2">
{isDismissed ? (
<ContentMessage>Connected. Continuing</ContentMessage>
) : (
<SetupRequirementsCard
output={output}
credentialsLabel={`${output.setup_info?.agent_name ?? providerName} credentials`}
retryInstruction="I've connected my account. Please continue."
onComplete={() => setIsDismissed(true)}
/>
)}
</div>
)}
</div>
);
}

View File

@@ -23,12 +23,16 @@ interface Props {
/** Override the label shown above the credentials section.
* Defaults to "Credentials". */
credentialsLabel?: string;
/** Called after Proceed is clicked so the parent can persist the dismissed state
* across remounts (avoids re-enabling the Proceed button on remount). */
onComplete?: () => void;
}
export function SetupRequirementsCard({
output,
retryInstruction,
credentialsLabel,
onComplete,
}: Props) {
const { onSend } = useCopilotChatActions();
@@ -68,13 +72,17 @@ export function SetupRequirementsCard({
return v !== undefined && v !== null && v !== "";
});
if (hasSent) {
return <ContentMessage>Connected. Continuing</ContentMessage>;
}
const canRun =
!hasSent &&
(!needsCredentials || isAllCredentialsComplete) &&
(!needsInputs || isAllInputsComplete);
function handleRun() {
setHasSent(true);
onComplete?.();
const parts: string[] = [];
if (needsCredentials) {

View File

@@ -125,9 +125,9 @@ export function useCredentialsInput({
if (hasAttemptedAutoSelect.current) return;
hasAttemptedAutoSelect.current = true;
// Auto-select if exactly one credential matches.
// For optional fields with multiple options, let the user choose.
if (isOptional && savedCreds.length > 1) return;
// Auto-select only when there is exactly one saved credential.
// With multiple options the user must choose — regardless of optional/required.
if (savedCreds.length > 1) return;
const cred = savedCreds[0];
onSelectCredential({