Compare commits

..

25 Commits

Author SHA1 Message Date
Zamil Majdy
885459c6e1 fix(backend/copilot): respect CLAUDE_CONFIG_DIR in SDK_PROJECTS_DIR constant
SDK_PROJECTS_DIR was hardcoded to ~/.claude/projects, ignoring the
CLAUDE_CONFIG_DIR environment variable. This caused path validation
mismatches in environments with custom Claude configurations.
Now consistent with transcript.py's _projects_base() function.
2026-03-17 13:49:45 +07:00
Zamil Majdy
38ff768a65 fix(backend/copilot): update test mock to use get_workspace_manager
The dev branch renamed get_manager to get_workspace_manager but the
test was still patching the old name, causing an AttributeError.
2026-03-17 12:00:34 +07:00
Zamil Majdy
76dbf3bbec Merge origin/dev into fix/copilot-tool-result-read
Resolve import conflict in workspace_files.py: keep both
get_workspace_manager (from dev) and get_sdk_cwd/is_allowed_local_path
(from this PR).
2026-03-17 07:11:15 +07:00
Zamil Majdy
c0ade4be68 fix(copilot): patch mock at service module level after top-level import refactor
The test was patching backend.copilot.sdk.transcript.cleanup_stale_project_dirs
but service.py now imports it at module level, creating its own binding. Patch
the symbol at the call site (service module) instead.
2026-03-16 06:27:43 +07:00
Zamil Majdy
24cbe738ff fix(copilot): address review items — top-level import, path sanitization, E2B_WORKDIR constant, st_mtime comment, no-fallback test
- Move `cleanup_stale_project_dirs` from deferred import inside `_cleanup_sdk_tool_results` to the top-level `from .transcript import (...)` block
- Sanitize `FileNotFoundError` message in `_read_local_tool_result` to use `os.path.basename(path)` instead of leaking the full path
- Replace hardcoded `/home/user` strings in `e2b_file_tools_test.py` with the `E2B_WORKDIR` constant
- Add `st_mtime` write-once invariant comment to `cleanup_stale_project_dirs` explaining why mtime reliably signals session activity
- Add test asserting the local-disk fallback is NOT invoked when `_resolve_file` succeeds
2026-03-16 06:15:59 +07:00
Zamil Majdy
49e031b54d fix(copilot): use _LOCAL_TOOL_RESULT_FILE_ID constant for text path in _read_local_tool_result
Replace remaining hardcoded "local" string with the named constant
_LOCAL_TOOL_RESULT_FILE_ID in the text-file return path of
_read_local_tool_result, completing the previous fix that only updated
the binary-file return path.
2026-03-15 22:55:59 +07:00
Zamil Majdy
4b6c2a1323 fix(backend/copilot): scope stale project-dir sweep to current session and expose encode_cwd_for_cli
Addresses multi-tenant safety concern: cleanup_stale_project_dirs now accepts
an optional encoded_cwd parameter that limits the sweep to just the current
session's directory instead of all ~/.claude/projects/ entries. Exposes
encode_cwd_for_cli as a public function from context.py and passes the encoded
cwd from _cleanup_sdk_tool_results. Adds three new tests covering scoped
sweep behaviour.
2026-03-15 22:44:48 +07:00
Zamil Majdy
c854c1a485 fix(copilot): address review items — symlink check for edit_file, local constant, sort removal, and unit tests
- Extend _check_sandbox_symlink_escape to _handle_edit_file for consistency
- Define _LOCAL_TOOL_RESULT_FILE_ID constant, replacing magic string "local"
- Replace sorted(Path.iterdir()) with plain iterdir() in cleanup_stale_project_dirs
- Add TestCheckSandboxSymlinkEscape unit tests (7 cases) in e2b_file_tools_test.py
- Add TestCleanupSdkToolResults unit tests (4 cases) in service_test.py covering rate-limiting and path rejection
2026-03-15 22:20:29 +07:00
Zamil Majdy
53a2c84796 fix(backend/copilot): add tests for local tool result reading and stale dir sweep 2026-03-15 04:10:05 +07:00
Zamil Majdy
3063ce22ac fix(copilot): add inline comments for timeout rationale and sweep safety
Address remaining PR review comments:
- Document 2.0s timeout reasoning at call site (was 0.5s, caused
  frequent timeouts under load)
- Document sleep(0) yield purpose after successful stash wait
- Clarify multi-tenant safety of sweep in docstring (12h TTL +
  pattern match ensures active sessions are never affected)
2026-03-14 23:44:25 +07:00
Zamil Majdy
69db0815c3 fix(backend/copilot): add defence-in-depth realpath check in is_allowed_local_path
Resolve project_dir via os.path.realpath and validate it stays within
SDK_PROJECTS_DIR before checking the resolved path. Guards against
potential future bugs in _encode_cwd_for_cli, matching the pattern
already used in transcript.py.
2026-03-14 23:42:13 +07:00
Zamil Majdy
775ed85bba fix(backend/copilot): sanitize error paths, add cleanup sweep, and harden file handling 2026-03-14 23:40:19 +07:00
Zamil Majdy
f07bb52ac3 fix: correct tool name guidance, UUID comment, and docstring path
- Support both read_file (E2B) and Read (non-E2B) in prompt guidance
- Fix UUID comment from "UUID-v4" to "UUID" (regex accepts all versions)
- Update security_hooks docstring to include UUID segment in path
2026-03-14 22:28:12 +07:00
Zamil Majdy
84482071a8 Merge remote-tracking branch 'origin/dev' into fix/copilot-tool-result-read 2026-03-14 22:11:16 +07:00
Zamil Majdy
9ac01a0cf6 fix(backend/copilot): harden tool-result reads, add disk sweep, remove dead code
- _read_local_tool_result: detect binary files (return raw base64 instead of
  corrupting with errors="replace"), add 10 MB size limit, move getsize inside
  try block, use consistent char units in messages
- Add cleanup_stale_project_dirs() to sweep CLI project dirs older than 6h,
  preventing unbounded disk growth from per-turn directory creation
- Add re.IGNORECASE to _UUID_RE for robust UUID matching
- Add TOCTOU acknowledgment to _check_sandbox_symlink_escape docstring
- Clarify transcript_path sanitization comment in security_hooks.py
- Remove dead code: read_cli_session_file, cleanup_cli_project_dir, _cli_project_dir,
  _safe_glob_jsonl (no remaining callers after cleanup changes)
- Add tests: TestReadLocalToolResult (6 cases), TestCleanupStaleProjectDirs (2 cases)
2026-03-14 22:09:59 +07:00
Zamil Majdy
71337c0514 Merge origin/dev into fix/copilot-tool-result-read
Resolve conflict in transcript.py by accepting new functions from dev
(_projects_base, _cli_project_dir, _safe_glob_jsonl,
read_compacted_entries, read_cli_session_file, cleanup_cli_project_dir).
2026-03-14 10:21:54 +07:00
Zamil Majdy
b2808f223a fix(backend/copilot): address review comments — text seek bug, symlink helper, cleanup simplification
- Fix invalid fh.seek() on text-mode file in _read_local_tool_result by
  reading full content and slicing (sentry bot bug report)
- Extract symlink escape check into _check_sandbox_symlink_escape helper
- Remove over-engineered TTL sweep of project dirs; just clean tmp dir
2026-03-13 22:07:33 +07:00
Zamil Majdy
85101bfc5b fix(backend/copilot): address third-bump review comments
- Add defence-in-depth is_allowed_local_path check in _read_local_tool_result
- Scope _sweep_stale_project_dirs to current session's encoded_dir only
- Remove dead cleanup_cli_project_dir from transcript.py
- Check readlink exit_code in e2b_file_tools symlink validation
- Remove redundant try/except around shutil.rmtree(ignore_errors=True)
- Add test for parts[1] != "tool-results" rejection path
- Rename _SDK_PROJECTS_DIR to SDK_PROJECTS_DIR (public API)
- Remove sleep(0) band-aid from wait_for_stash, add timeout justification
- Extract _UUID_RE compiled constant for conversation UUID validation
2026-03-13 19:54:00 +07:00
Zamil Majdy
3334a4b4b5 Merge remote-tracking branch 'origin/dev' into fix/copilot-tool-result-read 2026-03-13 19:27:06 +07:00
Zamil Majdy
796e737d77 fix(backend/copilot): address reviewer comments on tool-result PR
- Move local imports (time, _SDK_PROJECTS_DIR) to top-level in service.py
- Add UUID format regex validation for path segments in context.py
- Extract _latest_mtime helper to reduce nesting in _sweep_stale_project_dirs
- Use mimetypes.guess_type() instead of hardcoded mime_type in workspace_files.py
- Update test UUIDs to match the new strict UUID regex validation
2026-03-13 17:51:07 +07:00
Zamil Majdy
8d16f8052b fix(backend/copilot): ensure stream lock release even if cleanup fails
Wrap _cleanup_sdk_tool_results in try/finally so lock.release() is
always called, preventing session deadlocks on cleanup exceptions.
2026-03-13 16:32:40 +07:00
Zamil Majdy
1f8ab0687c fix(backend/copilot): offload sync cleanup to thread to avoid blocking event loop
Move filesystem IO in _cleanup_sdk_tool_results (shutil.rmtree and
_sweep_stale_project_dirs) to asyncio.to_thread so the async stream
generator's finally block doesn't block the event loop during cleanup.
2026-03-13 16:20:09 +07:00
Zamil Majdy
035aba9cf1 fix(backend/copilot): address PR review — mtime staleness and symlink escape
- Use max mtime across conv dir and immediate children (tool-results/)
  to avoid premature cleanup of active sessions whose directory mtime
  hasn't updated (addresses sentry bot review)
- Replace normpath-based re-validation with readlink -f inside the E2B
  sandbox to properly detect symlink escapes after mkdir (addresses
  coderabbit review)
2026-03-13 16:04:22 +07:00
Zamil Majdy
e0128470a9 fix(backend/copilot): harden tool-result path validation and address review feedback
- Tighten is_allowed_local_path to only allow UUID-nested tool-results
  paths (<encoded-cwd>/<uuid>/tool-results/<file>), rejecting the
  non-UUID pattern that isn't a real SDK flow
- Add TTL-based cleanup (24h) for stale conversation UUID dirs under
  ~/.claude/projects/ to prevent disk leak (addresses sentry bot review)
- Add path re-validation after mkdir in E2B write handler to prevent
  symlink escape
- Increase wait_for_stash timeout from 0.5s to 2.0s and add post-timeout
  retry to reduce PostToolUse hook race condition output loss
- Update all affected tests to use UUID-nested path pattern
2026-03-13 15:50:17 +07:00
Zamil Majdy
a4deae0f69 fix(backend/copilot): fix tool-result file read failures across turns
Three bugs caused "file not found" errors when the model tried to read
SDK tool-result files:

1. Path validation mismatch: is_allowed_local_path() expected
   tool-results directly under the project dir, but the SDK nests them
   under a conversation UUID subdirectory. Fixed to match any
   tool-results/ segment within the project dir.

2. Wrong tool fallback: when the model mistakenly called
   read_workspace_file (cloud storage) for SDK tool-result paths on
   local disk, it got "file not found". Added a fallback in
   ReadWorkspaceFileTool that detects allowed local paths and reads
   from disk instead.

3. Cross-turn cleanup: _cleanup_sdk_tool_results deleted the entire
   CLI project directory (including tool-results/) between turns.
   Subsequent turns referencing those paths via --resume transcript
   would fail. Removed the project dir cleanup — only the temp cwd
   is cleaned now.

Also added system prompt guidance telling the model to use read_file
(not read_workspace_file) for SDK tool-result paths.
2026-03-13 15:33:30 +07:00
25 changed files with 1539 additions and 3596 deletions

View File

@@ -115,7 +115,7 @@ class ChatConfig(BaseSettings):
description="E2B sandbox template to use for copilot sessions.",
)
e2b_sandbox_timeout: int = Field(
default=420, # 7 min safety net — allows headroom for compaction retries
default=300, # 5 min safety net — explicit per-turn pause is the primary mechanism
description="E2B sandbox running-time timeout (seconds). "
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
"mechanism; this is the safety net.",

View File

@@ -17,8 +17,17 @@ from backend.util.workspace import WorkspaceManager
if TYPE_CHECKING:
from e2b import AsyncSandbox
# Allowed base directory for the Read tool.
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# Allowed base directory for the Read tool. Public so service.py can use it
# for sweep operations without depending on a private implementation detail.
# Respects CLAUDE_CONFIG_DIR env var, consistent with transcript.py's
# _projects_base() function.
_config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
SDK_PROJECTS_DIR = os.path.realpath(os.path.join(_config_dir, "projects"))
# Compiled UUID pattern for validating conversation directory names.
# Kept as a module-level constant so the security-relevant pattern is easy
# to audit in one place and avoids recompilation on every call.
_UUID_RE = re.compile(r"^[0-9a-f]{8}(?:-[0-9a-f]{4}){3}-[0-9a-f]{12}$", re.IGNORECASE)
# Encoded project-directory name for the current session (e.g.
# "-private-tmp-copilot-<uuid>"). Set by set_execution_context() so path
@@ -35,11 +44,20 @@ _current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
def _encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does."""
def encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does.
The Claude CLI encodes the absolute cwd as a directory name by replacing
every non-alphanumeric character with ``-``. For example
``/tmp/copilot-abc`` becomes ``-tmp-copilot-abc``.
"""
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
# Keep the private alias for internal callers (backwards compat).
_encode_cwd_for_cli = encode_cwd_for_cli
def set_execution_context(
user_id: str | None,
session: ChatSession,
@@ -100,7 +118,9 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
Allowed:
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
- Files under ``~/.claude/projects/<encoded-cwd>/tool-results/`` (SDK tool-results)
- Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``.
The SDK nests tool-results under a conversation UUID directory;
the UUID segment is validated with ``_UUID_RE``.
"""
if not path:
return False
@@ -119,10 +139,22 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
encoded = _current_project_dir.get("")
if encoded:
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
if resolved == tool_results_dir or resolved.startswith(
tool_results_dir + os.sep
):
return True
project_dir = os.path.realpath(os.path.join(SDK_PROJECTS_DIR, encoded))
# Defence-in-depth: ensure project_dir didn't escape the base.
if not project_dir.startswith(SDK_PROJECTS_DIR + os.sep):
return False
# Only allow: <encoded-cwd>/<uuid>/tool-results/<file>
# The SDK always creates a conversation UUID directory between
# the project dir and tool-results/.
if resolved.startswith(project_dir + os.sep):
relative = resolved[len(project_dir) + 1 :]
parts = relative.split(os.sep)
# Require exactly: [<uuid>, "tool-results", <file>, ...]
if (
len(parts) >= 3
and _UUID_RE.match(parts[0])
and parts[1] == "tool-results"
):
return True
return False

View File

@@ -9,7 +9,7 @@ from unittest.mock import MagicMock
import pytest
from backend.copilot.context import (
_SDK_PROJECTS_DIR,
SDK_PROJECTS_DIR,
_current_project_dir,
get_current_sandbox,
get_execution_context,
@@ -104,11 +104,13 @@ def test_is_allowed_local_path_no_sdk_cwd_no_project_dir():
assert not is_allowed_local_path("/tmp/some-file.txt", sdk_cwd=None)
def test_is_allowed_local_path_tool_results_dir():
"""Files under the tool-results directory for the current project are allowed."""
def test_is_allowed_local_path_tool_results_with_uuid():
"""Files under <encoded-cwd>/<uuid>/tool-results/ are allowed."""
encoded = "test-encoded-dir"
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
path = os.path.join(tool_results_dir, "output.txt")
conv_uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
path = os.path.join(
SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-results", "output.txt"
)
_current_project_dir.set(encoded)
try:
@@ -117,10 +119,22 @@ def test_is_allowed_local_path_tool_results_dir():
_current_project_dir.set("")
def test_is_allowed_local_path_tool_results_without_uuid_rejected():
"""Direct <encoded-cwd>/tool-results/ (no UUID) is rejected."""
encoded = "test-encoded-dir"
path = os.path.join(SDK_PROJECTS_DIR, encoded, "tool-results", "output.txt")
_current_project_dir.set(encoded)
try:
assert not is_allowed_local_path(path, sdk_cwd=None)
finally:
_current_project_dir.set("")
def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
"""A path adjacent to tool-results/ but not inside it is rejected."""
encoded = "test-encoded-dir"
sibling_path = os.path.join(_SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
sibling_path = os.path.join(SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
_current_project_dir.set(encoded)
try:
@@ -129,6 +143,21 @@ def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
_current_project_dir.set("")
def test_is_allowed_local_path_valid_uuid_wrong_segment_name_rejected():
"""A valid UUID dir but non-'tool-results' second segment is rejected."""
encoded = "test-encoded-dir"
uuid_str = "12345678-1234-5678-9abc-def012345678"
path = os.path.join(
SDK_PROJECTS_DIR, encoded, uuid_str, "not-tool-results", "output.txt"
)
_current_project_dir.set(encoded)
try:
assert not is_allowed_local_path(path, sdk_cwd=None)
finally:
_current_project_dir.set("")
# ---------------------------------------------------------------------------
# resolve_sandbox_path
# ---------------------------------------------------------------------------

View File

@@ -152,6 +152,13 @@ def _build_storage_supplement(
### File persistence
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
### SDK tool-result files
When tool outputs are large, the SDK truncates them and saves the full output to
a local file under `~/.claude/projects/.../tool-results/`. To read these files,
always use `read_file` or `Read` (NOT `read_workspace_file`).
`read_workspace_file` reads from cloud workspace storage, where SDK
tool-results are NOT stored.
{_SHARED_TOOL_NOTES}"""

View File

@@ -43,7 +43,6 @@ class ResponseType(str, Enum):
ERROR = "error"
USAGE = "usage"
HEARTBEAT = "heartbeat"
STATUS = "status"
class StreamBaseResponse(BaseModel):
@@ -233,26 +232,3 @@ class StreamHeartbeat(StreamBaseResponse):
def to_sse(self) -> str:
"""Convert to SSE comment format to keep connection alive."""
return ": heartbeat\n\n"
class StreamStatus(StreamBaseResponse):
"""Transient status notification shown to the user during long operations.
Used to provide feedback when the backend performs behind-the-scenes work
(e.g., compacting conversation context on a retry) that would otherwise
leave the user staring at an unexplained pause.
"""
type: ResponseType = ResponseType.STATUS
message: str = Field(..., description="Human-readable status message")
def to_sse(self) -> str:
"""Encode as an SSE comment so the AI SDK stream parser ignores it.
The frontend AI SDK validates every ``data:`` line against a strict
Zod union of known chunk types. ``"status"`` is not in that union,
so sending it as ``data:`` would cause a schema-validation error that
breaks the entire stream. Using an SSE comment (``:``) keeps the
connection alive and is silently discarded by ``EventSource`` parsers.
"""
return f": status {self.message}\n\n"

View File

@@ -12,7 +12,6 @@ import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from typing import Any
from ..constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
from ..model import ChatMessage, ChatSession
@@ -120,12 +119,14 @@ def filter_compaction_messages(
filtered: list[ChatMessage] = []
for msg in messages:
if msg.role == "assistant" and msg.tool_calls:
real_calls: list[dict[str, Any]] = []
for tc in msg.tool_calls:
if tc.get("function", {}).get("name") == COMPACTION_TOOL_NAME:
compaction_ids.add(tc.get("id", ""))
else:
real_calls.append(tc)
real_calls = [
tc
for tc in msg.tool_calls
if tc.get("function", {}).get("name") != COMPACTION_TOOL_NAME
]
if not real_calls and not msg.content:
continue
if msg.role == "tool" and msg.tool_call_id in compaction_ids:
@@ -221,7 +222,6 @@ class CompactionTracker:
def reset_for_query(self) -> None:
"""Reset per-query state before a new SDK query."""
self._compact_start.clear()
self._done = False
self._start_emitted = False
self._tool_call_id = ""

View File

@@ -1,41 +0,0 @@
"""Shared test fixtures for copilot SDK tests."""
from __future__ import annotations
from uuid import uuid4
from backend.util import json
def build_test_transcript(pairs: list[tuple[str, str]]) -> str:
"""Build a minimal valid JSONL transcript from (role, content) pairs.
Use this helper in any copilot SDK test that needs a well-formed
transcript without hitting the real storage layer.
"""
lines: list[str] = []
last_uuid: str | None = None
for role, content in pairs:
uid = str(uuid4())
entry_type = "assistant" if role == "assistant" else "user"
msg: dict = {"role": role, "content": content}
if role == "assistant":
msg.update(
{
"model": "",
"id": f"msg_{uid[:8]}",
"type": "message",
"content": [{"type": "text", "text": content}],
"stop_reason": "end_turn",
"stop_sequence": None,
}
)
entry = {
"type": entry_type,
"uuid": uid,
"parentUuid": last_uuid,
"message": msg,
}
lines.append(json.dumps(entry, separators=(",", ":")))
last_uuid = uid
return "\n".join(lines) + "\n"

View File

@@ -26,6 +26,41 @@ from backend.copilot.context import (
logger = logging.getLogger(__name__)
async def _check_sandbox_symlink_escape(
sandbox: Any,
parent: str,
) -> str | None:
"""Resolve the canonical parent path inside the sandbox to detect symlink escapes.
``normpath`` (used by ``resolve_sandbox_path``) only normalises the string;
``readlink -f`` follows actual symlinks on the sandbox filesystem.
Returns the canonical parent path, or ``None`` if the path escapes
``E2B_WORKDIR``.
Note: There is an inherent TOCTOU window between this check and the
subsequent ``sandbox.files.write()``. A symlink could theoretically be
replaced between the two operations. This is acceptable in the E2B
sandbox model since the sandbox is single-user and ephemeral.
"""
canonical_res = await sandbox.commands.run(
f"readlink -f {shlex.quote(parent or E2B_WORKDIR)}",
cwd=E2B_WORKDIR,
timeout=5,
)
canonical_parent = (canonical_res.stdout or "").strip()
if (
canonical_res.exit_code != 0
or not canonical_parent
or (
canonical_parent != E2B_WORKDIR
and not canonical_parent.startswith(E2B_WORKDIR + "/")
)
):
return None
return canonical_parent
def _get_sandbox():
return get_current_sandbox()
@@ -106,6 +141,10 @@ async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
parent = os.path.dirname(remote)
if parent and parent != E2B_WORKDIR:
await sandbox.files.make_dir(parent)
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
if canonical_parent is None:
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
remote = os.path.join(canonical_parent, os.path.basename(remote))
await sandbox.files.write(remote, content)
except Exception as exc:
return _mcp(f"Failed to write {remote}: {exc}", error=True)
@@ -130,6 +169,12 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
return result
sandbox, remote = result
parent = os.path.dirname(remote)
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
if canonical_parent is None:
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
remote = os.path.join(canonical_parent, os.path.basename(remote))
try:
raw: bytes = await sandbox.files.read(remote, format="bytes")
content = raw.decode("utf-8", errors="replace")

View File

@@ -4,15 +4,19 @@ Pure unit tests with no external dependencies (no E2B, no sandbox).
"""
import os
import shutil
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from backend.copilot.context import _current_project_dir
from .e2b_file_tools import _read_local, resolve_sandbox_path
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
from backend.copilot.context import E2B_WORKDIR, SDK_PROJECTS_DIR, _current_project_dir
from .e2b_file_tools import (
_check_sandbox_symlink_escape,
_read_local,
resolve_sandbox_path,
)
# ---------------------------------------------------------------------------
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
@@ -21,46 +25,48 @@ _SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
class TestResolveSandboxPath:
def test_relative_path_resolved(self):
assert resolve_sandbox_path("src/main.py") == "/home/user/src/main.py"
assert resolve_sandbox_path("src/main.py") == f"{E2B_WORKDIR}/src/main.py"
def test_absolute_within_sandbox(self):
assert resolve_sandbox_path("/home/user/file.txt") == "/home/user/file.txt"
assert (
resolve_sandbox_path(f"{E2B_WORKDIR}/file.txt") == f"{E2B_WORKDIR}/file.txt"
)
def test_workdir_itself(self):
assert resolve_sandbox_path("/home/user") == "/home/user"
assert resolve_sandbox_path(E2B_WORKDIR) == E2B_WORKDIR
def test_relative_dotslash(self):
assert resolve_sandbox_path("./README.md") == "/home/user/README.md"
assert resolve_sandbox_path("./README.md") == f"{E2B_WORKDIR}/README.md"
def test_traversal_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
resolve_sandbox_path("../../etc/passwd")
def test_absolute_traversal_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
resolve_sandbox_path("/home/user/../../etc/passwd")
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
resolve_sandbox_path(f"{E2B_WORKDIR}/../../etc/passwd")
def test_absolute_outside_sandbox_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
resolve_sandbox_path("/etc/passwd")
def test_root_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
resolve_sandbox_path("/")
def test_home_other_user_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
resolve_sandbox_path("/home/other/file.txt")
def test_deep_nested_allowed(self):
assert resolve_sandbox_path("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
assert resolve_sandbox_path("a/b/c/d/e.txt") == f"{E2B_WORKDIR}/a/b/c/d/e.txt"
def test_trailing_slash_normalised(self):
assert resolve_sandbox_path("src/") == "/home/user/src"
assert resolve_sandbox_path("src/") == f"{E2B_WORKDIR}/src"
def test_double_dots_within_sandbox_ok(self):
"""Path that resolves back within /home/user is allowed."""
assert resolve_sandbox_path("a/b/../c.txt") == "/home/user/a/c.txt"
"""Path that resolves back within E2B_WORKDIR is allowed."""
assert resolve_sandbox_path("a/b/../c.txt") == f"{E2B_WORKDIR}/a/c.txt"
# ---------------------------------------------------------------------------
@@ -73,9 +79,13 @@ class TestResolveSandboxPath:
class TestReadLocal:
_CONV_UUID = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
def _make_tool_results_file(self, encoded: str, filename: str, content: str) -> str:
"""Create a tool-results file and return its path."""
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
"""Create a tool-results file under <encoded>/<uuid>/tool-results/."""
tool_results_dir = os.path.join(
SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-results"
)
os.makedirs(tool_results_dir, exist_ok=True)
filepath = os.path.join(tool_results_dir, filename)
with open(filepath, "w") as f:
@@ -107,7 +117,9 @@ class TestReadLocal:
def test_read_nonexistent_tool_results(self):
"""A tool-results path that doesn't exist returns FileNotFoundError."""
encoded = "-tmp-copilot-e2b-test-nofile"
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
tool_results_dir = os.path.join(
SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-results"
)
os.makedirs(tool_results_dir, exist_ok=True)
filepath = os.path.join(tool_results_dir, "nonexistent.txt")
token = _current_project_dir.set(encoded)
@@ -117,7 +129,7 @@ class TestReadLocal:
assert "not found" in result["content"][0]["text"].lower()
finally:
_current_project_dir.reset(token)
os.rmdir(tool_results_dir)
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
def test_read_traversal_path_blocked(self):
"""A traversal attempt that escapes allowed directories is blocked."""
@@ -152,3 +164,66 @@ class TestReadLocal:
"""Without _current_project_dir set, all paths are blocked."""
result = _read_local("/tmp/anything.txt", offset=0, limit=10)
assert result["isError"] is True
# ---------------------------------------------------------------------------
# _check_sandbox_symlink_escape — symlink escape detection
# ---------------------------------------------------------------------------
def _make_sandbox(stdout: str, exit_code: int = 0) -> SimpleNamespace:
"""Build a minimal sandbox mock whose commands.run returns a fixed result."""
run_result = SimpleNamespace(stdout=stdout, exit_code=exit_code)
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
return SimpleNamespace(commands=commands)
class TestCheckSandboxSymlinkEscape:
@pytest.mark.asyncio
async def test_canonical_path_within_workdir_returns_path(self):
"""When readlink -f resolves to a path inside E2B_WORKDIR, returns it."""
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}/src\n", exit_code=0)
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
assert result == f"{E2B_WORKDIR}/src"
@pytest.mark.asyncio
async def test_workdir_itself_returns_workdir(self):
"""When readlink -f resolves to E2B_WORKDIR exactly, returns E2B_WORKDIR."""
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}\n", exit_code=0)
result = await _check_sandbox_symlink_escape(sandbox, E2B_WORKDIR)
assert result == E2B_WORKDIR
@pytest.mark.asyncio
async def test_symlink_escape_returns_none(self):
"""When readlink -f resolves outside E2B_WORKDIR (symlink escape), returns None."""
sandbox = _make_sandbox(stdout="/etc\n", exit_code=0)
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/evil")
assert result is None
@pytest.mark.asyncio
async def test_nonzero_exit_code_returns_none(self):
"""A non-zero exit code from readlink -f returns None."""
sandbox = _make_sandbox(stdout="", exit_code=1)
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
assert result is None
@pytest.mark.asyncio
async def test_empty_stdout_returns_none(self):
"""Empty stdout from readlink (e.g. path doesn't exist yet) returns None."""
sandbox = _make_sandbox(stdout="", exit_code=0)
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
assert result is None
@pytest.mark.asyncio
async def test_prefix_collision_returns_none(self):
"""A path prefixed with E2B_WORKDIR but not within it is rejected."""
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}-evil\n", exit_code=0)
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}-evil")
assert result is None
@pytest.mark.asyncio
async def test_deeply_nested_path_within_workdir(self):
"""Deep nested paths inside E2B_WORKDIR are allowed."""
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}/a/b/c/d\n", exit_code=0)
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/a/b/c/d")
assert result == f"{E2B_WORKDIR}/a/b/c/d"

View File

@@ -1,552 +0,0 @@
"""Tests for retry logic and transcript compaction helpers."""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
from uuid import uuid4
import pytest
from backend.util import json
from .conftest import build_test_transcript as _build_transcript
from .service import _is_prompt_too_long
from .transcript import (
_flatten_assistant_content,
_flatten_tool_result_content,
_messages_to_transcript,
_transcript_to_messages,
compact_transcript,
validate_transcript,
)
# ---------------------------------------------------------------------------
# _flatten_assistant_content
# ---------------------------------------------------------------------------
class TestFlattenAssistantContent:
def test_text_blocks(self):
blocks = [
{"type": "text", "text": "Hello"},
{"type": "text", "text": "World"},
]
assert _flatten_assistant_content(blocks) == "Hello\nWorld"
def test_tool_use_blocks(self):
blocks = [{"type": "tool_use", "name": "read_file", "input": {}}]
assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
def test_mixed_blocks(self):
blocks = [
{"type": "text", "text": "Let me read that."},
{"type": "tool_use", "name": "Read", "input": {"path": "/foo"}},
]
result = _flatten_assistant_content(blocks)
assert "Let me read that." in result
assert "[tool_use: Read]" in result
def test_raw_strings(self):
assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"
def test_unknown_block_type_preserved_as_placeholder(self):
blocks = [
{"type": "text", "text": "See this image:"},
{"type": "image", "source": {"type": "base64", "data": "..."}},
]
result = _flatten_assistant_content(blocks)
assert "See this image:" in result
assert "[__image__]" in result
def test_empty(self):
assert _flatten_assistant_content([]) == ""
# ---------------------------------------------------------------------------
# _flatten_tool_result_content
# ---------------------------------------------------------------------------
class TestFlattenToolResultContent:
def test_tool_result_with_text(self):
blocks = [
{
"type": "tool_result",
"tool_use_id": "123",
"content": [{"type": "text", "text": "file contents here"}],
}
]
assert _flatten_tool_result_content(blocks) == "file contents here"
def test_tool_result_with_string_content(self):
blocks = [{"type": "tool_result", "tool_use_id": "123", "content": "ok"}]
assert _flatten_tool_result_content(blocks) == "ok"
def test_text_block(self):
blocks = [{"type": "text", "text": "plain text"}]
assert _flatten_tool_result_content(blocks) == "plain text"
def test_raw_string(self):
assert _flatten_tool_result_content(["raw"]) == "raw"
def test_tool_result_with_none_content(self):
"""tool_result with content=None should produce empty string."""
blocks = [{"type": "tool_result", "tool_use_id": "x", "content": None}]
assert _flatten_tool_result_content(blocks) == ""
def test_tool_result_with_empty_list_content(self):
"""tool_result with content=[] should produce empty string."""
blocks = [{"type": "tool_result", "tool_use_id": "x", "content": []}]
assert _flatten_tool_result_content(blocks) == ""
def test_empty(self):
assert _flatten_tool_result_content([]) == ""
def test_nested_dict_without_text(self):
"""Dict blocks without text key use json.dumps fallback."""
blocks = [
{
"type": "tool_result",
"tool_use_id": "x",
"content": [{"type": "image", "source": "data:..."}],
}
]
result = _flatten_tool_result_content(blocks)
assert "image" in result # json.dumps fallback
def test_unknown_block_type_preserved_as_placeholder(self):
blocks = [{"type": "image", "source": {"type": "base64", "data": "..."}}]
result = _flatten_tool_result_content(blocks)
assert "[__image__]" in result
# ---------------------------------------------------------------------------
# _transcript_to_messages
# ---------------------------------------------------------------------------
def _make_entry(entry_type: str, role: str, content: str | list, **kwargs) -> str:
"""Build a JSONL line for testing."""
uid = str(uuid4())
msg: dict = {"role": role, "content": content}
msg.update(kwargs)
entry = {
"type": entry_type,
"uuid": uid,
"parentUuid": None,
"message": msg,
}
return json.dumps(entry, separators=(",", ":"))
class TestTranscriptToMessages:
def test_basic_roundtrip(self):
lines = [
_make_entry("user", "user", "Hello"),
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
]
content = "\n".join(lines) + "\n"
messages = _transcript_to_messages(content)
assert len(messages) == 2
assert messages[0] == {"role": "user", "content": "Hello"}
assert messages[1] == {"role": "assistant", "content": "Hi"}
def test_skips_strippable_types(self):
"""Progress and metadata entries are excluded."""
lines = [
_make_entry("user", "user", "Hello"),
json.dumps(
{
"type": "progress",
"uuid": str(uuid4()),
"parentUuid": None,
"message": {"role": "assistant", "content": "..."},
}
),
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
]
content = "\n".join(lines) + "\n"
messages = _transcript_to_messages(content)
assert len(messages) == 2
def test_empty_content(self):
assert _transcript_to_messages("") == []
def test_tool_result_content(self):
"""User entries with tool_result content blocks are flattened."""
lines = [
_make_entry(
"user",
"user",
[
{
"type": "tool_result",
"tool_use_id": "123",
"content": "tool output",
}
],
),
]
content = "\n".join(lines) + "\n"
messages = _transcript_to_messages(content)
assert len(messages) == 1
assert messages[0]["content"] == "tool output"
def test_malformed_json_lines_skipped(self):
"""Malformed JSON lines in transcript are silently skipped."""
lines = [
_make_entry("user", "user", "Hello"),
"this is not valid json",
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
]
content = "\n".join(lines) + "\n"
messages = _transcript_to_messages(content)
assert len(messages) == 2
def test_empty_lines_skipped(self):
"""Empty lines and whitespace-only lines are skipped."""
lines = [
_make_entry("user", "user", "Hello"),
"",
" ",
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
]
content = "\n".join(lines) + "\n"
messages = _transcript_to_messages(content)
assert len(messages) == 2
def test_unicode_content_preserved(self):
"""Unicode characters survive transcript roundtrip."""
lines = [
_make_entry("user", "user", "Hello 你好 🌍"),
_make_entry(
"assistant",
"assistant",
[{"type": "text", "text": "Bonjour 日本語 émojis 🎉"}],
),
]
content = "\n".join(lines) + "\n"
messages = _transcript_to_messages(content)
assert messages[0]["content"] == "Hello 你好 🌍"
assert messages[1]["content"] == "Bonjour 日本語 émojis 🎉"
def test_entry_without_role_skipped(self):
"""Entries with missing role in message are skipped."""
entry_no_role = json.dumps(
{
"type": "user",
"uuid": str(uuid4()),
"parentUuid": None,
"message": {"content": "no role here"},
}
)
lines = [
entry_no_role,
_make_entry("user", "user", "Hello"),
]
content = "\n".join(lines) + "\n"
messages = _transcript_to_messages(content)
assert len(messages) == 1
assert messages[0]["content"] == "Hello"
def test_tool_use_and_result_pairs(self):
"""Tool use + tool result pairs are properly flattened."""
lines = [
_make_entry(
"assistant",
"assistant",
[
{"type": "text", "text": "Let me check."},
{"type": "tool_use", "name": "read_file", "input": {"path": "/x"}},
],
),
_make_entry(
"user",
"user",
[
{
"type": "tool_result",
"tool_use_id": "abc",
"content": [{"type": "text", "text": "file contents"}],
}
],
),
]
content = "\n".join(lines) + "\n"
messages = _transcript_to_messages(content)
assert len(messages) == 2
assert "Let me check." in messages[0]["content"]
assert "[tool_use: read_file]" in messages[0]["content"]
assert messages[1]["content"] == "file contents"
# ---------------------------------------------------------------------------
# _messages_to_transcript
# ---------------------------------------------------------------------------
class TestMessagesToTranscript:
def test_produces_valid_jsonl(self):
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there"},
]
result = _messages_to_transcript(messages)
lines = result.strip().split("\n")
assert len(lines) == 2
for line in lines:
parsed = json.loads(line)
assert "type" in parsed
assert "uuid" in parsed
assert "message" in parsed
def test_assistant_has_proper_structure(self):
messages = [{"role": "assistant", "content": "Hello"}]
result = _messages_to_transcript(messages)
entry = json.loads(result.strip())
assert entry["type"] == "assistant"
msg = entry["message"]
assert msg["role"] == "assistant"
assert msg["type"] == "message"
assert msg["stop_reason"] == "end_turn"
assert isinstance(msg["content"], list)
assert msg["content"][0]["type"] == "text"
def test_user_has_plain_content(self):
messages = [{"role": "user", "content": "Hi"}]
result = _messages_to_transcript(messages)
entry = json.loads(result.strip())
assert entry["type"] == "user"
assert entry["message"]["content"] == "Hi"
def test_parent_uuid_chain(self):
messages = [
{"role": "user", "content": "A"},
{"role": "assistant", "content": "B"},
{"role": "user", "content": "C"},
]
result = _messages_to_transcript(messages)
lines = result.strip().split("\n")
entries = [json.loads(line) for line in lines]
assert entries[0]["parentUuid"] == ""
assert entries[1]["parentUuid"] == entries[0]["uuid"]
assert entries[2]["parentUuid"] == entries[1]["uuid"]
def test_empty_messages(self):
assert _messages_to_transcript([]) == ""
def test_output_is_valid_transcript(self):
"""Output should pass validate_transcript if it has assistant entries."""
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi"},
]
result = _messages_to_transcript(messages)
assert validate_transcript(result)
def test_roundtrip_to_messages(self):
"""Messages → transcript → messages preserves structure."""
original = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there"},
{"role": "user", "content": "How are you?"},
]
transcript = _messages_to_transcript(original)
restored = _transcript_to_messages(transcript)
assert len(restored) == len(original)
for orig, rest in zip(original, restored):
assert orig["role"] == rest["role"]
assert orig["content"] == rest["content"]
# ---------------------------------------------------------------------------
# compact_transcript
# ---------------------------------------------------------------------------
class TestCompactTranscript:
@pytest.mark.asyncio
async def test_too_few_messages_returns_none(self):
"""compact_transcript returns None when transcript has < 2 messages."""
transcript = _build_transcript([("user", "Hello")])
with patch(
"backend.copilot.config.ChatConfig",
return_value=type(
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
)(),
):
result = await compact_transcript(transcript, model="test-model")
assert result is None
@pytest.mark.asyncio
async def test_returns_none_when_not_compacted(self):
"""When compress_context says no compaction needed, returns None.
The compressor couldn't reduce it, so retrying with the same
content would fail identically."""
transcript = _build_transcript(
[
("user", "Hello"),
("assistant", "Hi there"),
]
)
mock_result = type(
"CompressResult",
(),
{
"was_compacted": False,
"messages": [],
"original_token_count": 100,
"token_count": 100,
"messages_summarized": 0,
"messages_dropped": 0,
},
)()
with (
patch(
"backend.copilot.config.ChatConfig",
return_value=type(
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
)(),
),
patch(
"backend.copilot.sdk.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
),
):
result = await compact_transcript(transcript, model="test-model")
assert result is None
@pytest.mark.asyncio
async def test_returns_compacted_transcript(self):
"""When compaction succeeds, returns a valid compacted transcript."""
transcript = _build_transcript(
[
("user", "Hello"),
("assistant", "Hi"),
("user", "More"),
("assistant", "Details"),
]
)
compacted_msgs = [
{"role": "user", "content": "[summary]"},
{"role": "assistant", "content": "Summarized response"},
]
mock_result = type(
"CompressResult",
(),
{
"was_compacted": True,
"messages": compacted_msgs,
"original_token_count": 500,
"token_count": 100,
"messages_summarized": 2,
"messages_dropped": 0,
},
)()
with (
patch(
"backend.copilot.config.ChatConfig",
return_value=type(
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
)(),
),
patch(
"backend.copilot.sdk.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
),
):
result = await compact_transcript(transcript, model="test-model")
assert result is not None
assert validate_transcript(result)
msgs = _transcript_to_messages(result)
assert len(msgs) == 2
assert msgs[1]["content"] == "Summarized response"
@pytest.mark.asyncio
async def test_returns_none_on_compression_failure(self):
"""When _run_compression raises, returns None."""
transcript = _build_transcript(
[
("user", "Hello"),
("assistant", "Hi"),
]
)
with (
patch(
"backend.copilot.config.ChatConfig",
return_value=type(
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
)(),
),
patch(
"backend.copilot.sdk.transcript._run_compression",
new_callable=AsyncMock,
side_effect=RuntimeError("LLM unavailable"),
),
):
result = await compact_transcript(transcript, model="test-model")
assert result is None
# ---------------------------------------------------------------------------
# _is_prompt_too_long
# ---------------------------------------------------------------------------
class TestIsPromptTooLong:
"""Unit tests for _is_prompt_too_long pattern matching."""
def test_prompt_is_too_long(self):
err = RuntimeError("prompt is too long for model context")
assert _is_prompt_too_long(err) is True
def test_request_too_large(self):
err = Exception("request too large: 250000 tokens")
assert _is_prompt_too_long(err) is True
def test_maximum_context_length(self):
err = ValueError("maximum context length exceeded")
assert _is_prompt_too_long(err) is True
def test_context_length_exceeded(self):
err = Exception("context_length_exceeded")
assert _is_prompt_too_long(err) is True
def test_input_tokens_exceed(self):
err = Exception("input tokens exceed the max_tokens limit")
assert _is_prompt_too_long(err) is True
def test_input_is_too_long(self):
err = Exception("input is too long for the model")
assert _is_prompt_too_long(err) is True
def test_content_length_exceeds(self):
err = Exception("content length exceeds maximum")
assert _is_prompt_too_long(err) is True
def test_unrelated_error_returns_false(self):
err = RuntimeError("network timeout")
assert _is_prompt_too_long(err) is False
def test_auth_error_returns_false(self):
err = Exception("authentication failed: invalid API key")
assert _is_prompt_too_long(err) is False
def test_chained_exception_detected(self):
"""Prompt-too-long error wrapped in another exception is detected."""
inner = RuntimeError("prompt is too long")
outer = Exception("SDK error")
outer.__cause__ = inner
assert _is_prompt_too_long(outer) is True
def test_case_insensitive(self):
err = Exception("PROMPT IS TOO LONG")
assert _is_prompt_too_long(err) is True
def test_old_max_tokens_exceeded_not_matched(self):
"""The old broad 'max_tokens_exceeded' pattern was removed.
Only 'input tokens exceed' should match now."""
err = Exception("max_tokens_exceeded")
assert _is_prompt_too_long(err) is False

View File

@@ -226,7 +226,7 @@ class SDKResponseAdapter:
responses.append(StreamFinish())
else:
logger.debug("Unhandled SDK message type: %s", type(sdk_message).__name__)
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
return responses

View File

@@ -42,7 +42,7 @@ def _validate_workspace_path(
Delegates to :func:`is_allowed_local_path` which permits:
- The SDK working directory (``/tmp/copilot-<session>/``)
- The current session's tool-results directory
(``~/.claude/projects/<encoded-cwd>/tool-results/``)
(``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/``)
"""
path = tool_input.get("file_path") or tool_input.get("path") or ""
if not path:
@@ -52,7 +52,7 @@ def _validate_workspace_path(
if is_allowed_local_path(path, sdk_cwd):
return {}
logger.warning("Blocked %s outside workspace: %s", tool_name, path)
logger.warning(f"Blocked {tool_name} outside workspace: {path}")
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
return _deny(
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
@@ -71,7 +71,7 @@ def _validate_tool_access(
"""
# Block forbidden tools
if tool_name in BLOCKED_TOOLS:
logger.warning("Blocked tool access attempt: %s", tool_name)
logger.warning(f"Blocked tool access attempt: {tool_name}")
return _deny(
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
"This is enforced by the platform and cannot be bypassed. "
@@ -89,9 +89,7 @@ def _validate_tool_access(
for pattern in DANGEROUS_PATTERNS:
if re.search(pattern, input_str, re.IGNORECASE):
logger.warning(
"Blocked dangerous pattern in tool input: %s in %s",
pattern,
tool_name,
f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
)
return _deny(
"[SECURITY] Input contains a blocked pattern. "
@@ -113,9 +111,7 @@ def _validate_user_isolation(
# the tool itself via _validate_ephemeral_path.
path = tool_input.get("path", "") or tool_input.get("file_path", "")
if path and ".." in path:
logger.warning(
"Blocked path traversal attempt: %s by user %s", path, user_id
)
logger.warning(f"Blocked path traversal attempt: {path} by user {user_id}")
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
@@ -174,7 +170,7 @@ def create_security_hooks(
# Block background task execution first — denied calls
# should not consume a subtask slot.
if tool_input.get("run_in_background"):
logger.info("[SDK] Blocked background Task, user=%s", user_id)
logger.info(f"[SDK] Blocked background Task, user={user_id}")
return cast(
SyncHookJSONOutput,
_deny(
@@ -185,9 +181,7 @@ def create_security_hooks(
)
if len(task_tool_use_ids) >= max_subtasks:
logger.warning(
"[SDK] Task limit reached (%d), user=%s",
max_subtasks,
user_id,
f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
)
return cast(
SyncHookJSONOutput,
@@ -218,7 +212,7 @@ def create_security_hooks(
if tool_name == "Task" and tool_use_id is not None:
task_tool_use_ids.add(tool_use_id)
logger.debug("[SDK] Tool start: %s, user=%s", tool_name, user_id)
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
return cast(SyncHookJSONOutput, {})
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:
@@ -288,11 +282,8 @@ def create_security_hooks(
tool_name = cast(str, input_data.get("tool_name", ""))
error = input_data.get("error", "Unknown error")
logger.warning(
"[SDK] Tool failed: %s, error=%s, user=%s, tool_use_id=%s",
tool_name,
str(error).replace("\n", "").replace("\r", ""),
user_id,
tool_use_id,
f"[SDK] Tool failed: {tool_name}, error={error}, "
f"user={user_id}, tool_use_id={tool_use_id}"
)
_release_task_slot(tool_name, tool_use_id)
@@ -310,19 +301,20 @@ def create_security_hooks(
This hook provides visibility into when compaction happens.
"""
_ = context, tool_use_id
# Sanitize untrusted input before logging to prevent log injection
trigger = (
str(input_data.get("trigger", "auto"))
.replace("\n", "")
.replace("\r", "")
)
trigger = input_data.get("trigger", "auto")
# Sanitize untrusted input: strip control chars for logging AND
# for the value passed downstream. read_compacted_entries()
# validates against _projects_base() as defence-in-depth, but
# sanitizing here prevents log injection and rejects obviously
# malformed paths early.
transcript_path = (
str(input_data.get("transcript_path", ""))
.replace("\n", "")
.replace("\r", "")
)
logger.info(
"[SDK] Context compaction triggered: %s, user=%s, transcript_path=%s",
"[SDK] Context compaction triggered: %s, user=%s, "
"transcript_path=%s",
trigger,
user_id,
transcript_path,

View File

@@ -122,7 +122,7 @@ def test_read_no_cwd_denies_absolute():
def test_read_tool_results_allowed():
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
path = f"{home}/.claude/projects/-tmp-copilot-abc123/a1b2c3d4-e5f6-7890-abcd-ef1234567890/tool-results/12345.txt"
# is_allowed_local_path requires the session's encoded cwd to be set
token = _current_project_dir.set("-tmp-copilot-abc123")
try:

File diff suppressed because it is too large Load Diff

View File

@@ -1,283 +0,0 @@
"""Unit tests for extracted service helpers.
Covers ``_is_prompt_too_long``, ``_reduce_context``, ``_iter_sdk_messages``,
and the ``ReducedContext`` named tuple.
"""
from __future__ import annotations
import asyncio
from collections.abc import AsyncGenerator
from unittest.mock import AsyncMock, patch
import pytest
from .conftest import build_test_transcript as _build_transcript
from .service import (
ReducedContext,
_is_prompt_too_long,
_iter_sdk_messages,
_reduce_context,
)
# ---------------------------------------------------------------------------
# _is_prompt_too_long
# ---------------------------------------------------------------------------
class TestIsPromptTooLong:
def test_direct_match(self) -> None:
assert _is_prompt_too_long(Exception("prompt is too long")) is True
def test_case_insensitive(self) -> None:
assert _is_prompt_too_long(Exception("PROMPT IS TOO LONG")) is True
def test_no_match(self) -> None:
assert _is_prompt_too_long(Exception("network timeout")) is False
def test_request_too_large(self) -> None:
assert _is_prompt_too_long(Exception("request too large for model")) is True
def test_context_length_exceeded(self) -> None:
assert _is_prompt_too_long(Exception("context_length_exceeded")) is True
def test_max_tokens_exceeded_not_matched(self) -> None:
"""'max_tokens_exceeded' is intentionally excluded (too broad)."""
assert _is_prompt_too_long(Exception("max_tokens_exceeded")) is False
def test_max_tokens_config_error_no_match(self) -> None:
"""'max_tokens must be at least 1' should NOT match."""
assert _is_prompt_too_long(Exception("max_tokens must be at least 1")) is False
def test_chained_cause(self) -> None:
inner = Exception("prompt is too long")
outer = RuntimeError("SDK error")
outer.__cause__ = inner
assert _is_prompt_too_long(outer) is True
def test_chained_context(self) -> None:
inner = Exception("request too large")
outer = RuntimeError("wrapped")
outer.__context__ = inner
assert _is_prompt_too_long(outer) is True
def test_deep_chain(self) -> None:
bottom = Exception("maximum context length")
middle = RuntimeError("middle")
middle.__cause__ = bottom
top = ValueError("top")
top.__cause__ = middle
assert _is_prompt_too_long(top) is True
def test_chain_no_match(self) -> None:
inner = Exception("rate limit exceeded")
outer = RuntimeError("wrapped")
outer.__cause__ = inner
assert _is_prompt_too_long(outer) is False
def test_cycle_detection(self) -> None:
"""Exception chain with a cycle should not infinite-loop."""
a = Exception("error a")
b = Exception("error b")
a.__cause__ = b
b.__cause__ = a # cycle
assert _is_prompt_too_long(a) is False
def test_all_patterns(self) -> None:
patterns = [
"prompt is too long",
"request too large",
"maximum context length",
"context_length_exceeded",
"input tokens exceed",
"input is too long",
"content length exceeds",
]
for pattern in patterns:
assert _is_prompt_too_long(Exception(pattern)) is True, pattern
# ---------------------------------------------------------------------------
# _reduce_context
# ---------------------------------------------------------------------------
class TestReduceContext:
@pytest.mark.asyncio
async def test_first_retry_compaction_success(self) -> None:
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
compacted = _build_transcript([("user", "hi"), ("assistant", "[summary]")])
with (
patch(
"backend.copilot.sdk.service.compact_transcript",
new_callable=AsyncMock,
return_value=compacted,
),
patch(
"backend.copilot.sdk.service.validate_transcript",
return_value=True,
),
patch(
"backend.copilot.sdk.service.write_transcript_to_tempfile",
return_value="/tmp/resume.jsonl",
),
):
ctx = await _reduce_context(
transcript, False, "sess-123", "/tmp/cwd", "[test]"
)
assert isinstance(ctx, ReducedContext)
assert ctx.use_resume is True
assert ctx.resume_file == "/tmp/resume.jsonl"
assert ctx.transcript_lost is False
assert ctx.tried_compaction is True
@pytest.mark.asyncio
async def test_compaction_fails_drops_transcript(self) -> None:
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
with patch(
"backend.copilot.sdk.service.compact_transcript",
new_callable=AsyncMock,
return_value=None,
):
ctx = await _reduce_context(
transcript, False, "sess-123", "/tmp/cwd", "[test]"
)
assert ctx.use_resume is False
assert ctx.resume_file is None
assert ctx.transcript_lost is True
assert ctx.tried_compaction is True
@pytest.mark.asyncio
async def test_already_tried_compaction_skips(self) -> None:
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
ctx = await _reduce_context(transcript, True, "sess-123", "/tmp/cwd", "[test]")
assert ctx.use_resume is False
assert ctx.transcript_lost is True
assert ctx.tried_compaction is True
@pytest.mark.asyncio
async def test_empty_transcript_drops(self) -> None:
ctx = await _reduce_context("", False, "sess-123", "/tmp/cwd", "[test]")
assert ctx.use_resume is False
assert ctx.transcript_lost is True
@pytest.mark.asyncio
async def test_compaction_returns_same_content_drops(self) -> None:
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
with patch(
"backend.copilot.sdk.service.compact_transcript",
new_callable=AsyncMock,
return_value=transcript, # same content
):
ctx = await _reduce_context(
transcript, False, "sess-123", "/tmp/cwd", "[test]"
)
assert ctx.transcript_lost is True
@pytest.mark.asyncio
async def test_write_tempfile_fails_drops(self) -> None:
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
compacted = _build_transcript([("user", "hi"), ("assistant", "[summary]")])
with (
patch(
"backend.copilot.sdk.service.compact_transcript",
new_callable=AsyncMock,
return_value=compacted,
),
patch(
"backend.copilot.sdk.service.validate_transcript",
return_value=True,
),
patch(
"backend.copilot.sdk.service.write_transcript_to_tempfile",
return_value=None,
),
):
ctx = await _reduce_context(
transcript, False, "sess-123", "/tmp/cwd", "[test]"
)
assert ctx.transcript_lost is True
# ---------------------------------------------------------------------------
# _iter_sdk_messages
# ---------------------------------------------------------------------------
class TestIterSdkMessages:
@pytest.mark.asyncio
async def test_yields_messages(self) -> None:
messages = ["msg1", "msg2", "msg3"]
client = AsyncMock()
async def _fake_receive() -> AsyncGenerator[str]:
for m in messages:
yield m
client.receive_response = _fake_receive
result = [msg async for msg in _iter_sdk_messages(client)]
assert result == messages
@pytest.mark.asyncio
async def test_heartbeat_on_timeout(self) -> None:
"""Yields None when asyncio.wait times out."""
client = AsyncMock()
received: list = []
async def _slow_receive() -> AsyncGenerator[str]:
await asyncio.sleep(100) # never completes
yield "never" # pragma: no cover — unreachable, yield makes this an async generator
client.receive_response = _slow_receive
with patch("backend.copilot.sdk.service._HEARTBEAT_INTERVAL", 0.01):
count = 0
async for msg in _iter_sdk_messages(client):
received.append(msg)
count += 1
if count >= 3:
break
assert all(m is None for m in received)
@pytest.mark.asyncio
async def test_exception_propagates(self) -> None:
client = AsyncMock()
async def _error_receive() -> AsyncGenerator[str]:
raise RuntimeError("SDK crash")
yield # pragma: no cover — unreachable, yield makes this an async generator
client.receive_response = _error_receive
with pytest.raises(RuntimeError, match="SDK crash"):
async for _ in _iter_sdk_messages(client):
pass
@pytest.mark.asyncio
async def test_task_cleanup_on_break(self) -> None:
"""Pending task is cancelled when generator is closed."""
client = AsyncMock()
async def _slow_receive() -> AsyncGenerator[str]:
yield "first"
await asyncio.sleep(100)
yield "second"
client.receive_response = _slow_receive
gen = _iter_sdk_messages(client)
first = await gen.__anext__()
assert first == "first"
await gen.aclose() # should cancel pending task cleanly

View File

@@ -288,3 +288,90 @@ class TestPromptSupplement:
# Count how many times this tool appears as a bullet point
count = docs.count(f"- **`{tool_name}`**")
assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)"
# ---------------------------------------------------------------------------
# _cleanup_sdk_tool_results — orchestration + rate-limiting
# ---------------------------------------------------------------------------
class TestCleanupSdkToolResults:
"""Tests for _cleanup_sdk_tool_results orchestration and sweep rate-limiting."""
# All valid cwds must start with /tmp/copilot- (the _SDK_CWD_PREFIX).
_CWD_PREFIX = "/tmp/copilot-"
@pytest.mark.asyncio
async def test_removes_cwd_directory(self):
"""Cleanup removes the session working directory."""
from .service import _cleanup_sdk_tool_results
cwd = "/tmp/copilot-test-cleanup-remove"
os.makedirs(cwd, exist_ok=True)
with patch("backend.copilot.sdk.service.cleanup_stale_project_dirs"):
import backend.copilot.sdk.service as svc_mod
svc_mod._last_sweep_time = 0.0
await _cleanup_sdk_tool_results(cwd)
assert not os.path.exists(cwd)
@pytest.mark.asyncio
async def test_sweep_runs_when_interval_elapsed(self):
"""cleanup_stale_project_dirs is called when 5-minute interval has elapsed."""
import backend.copilot.sdk.service as svc_mod
from .service import _cleanup_sdk_tool_results
cwd = "/tmp/copilot-test-sweep-elapsed"
os.makedirs(cwd, exist_ok=True)
with patch(
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
) as mock_sweep:
# Set last sweep to a time far in the past
svc_mod._last_sweep_time = 0.0
await _cleanup_sdk_tool_results(cwd)
mock_sweep.assert_called_once()
@pytest.mark.asyncio
async def test_sweep_skipped_within_interval(self):
"""cleanup_stale_project_dirs is NOT called when within 5-minute interval."""
import time
import backend.copilot.sdk.service as svc_mod
from .service import _cleanup_sdk_tool_results
cwd = "/tmp/copilot-test-sweep-ratelimit"
os.makedirs(cwd, exist_ok=True)
with patch(
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
) as mock_sweep:
# Set last sweep to now — interval not elapsed
svc_mod._last_sweep_time = time.time()
await _cleanup_sdk_tool_results(cwd)
mock_sweep.assert_not_called()
@pytest.mark.asyncio
async def test_rejects_path_outside_prefix(self, tmp_path):
"""Cleanup rejects a cwd that does not start with the expected prefix."""
from .service import _cleanup_sdk_tool_results
evil_cwd = str(tmp_path / "evil-path")
os.makedirs(evil_cwd, exist_ok=True)
with patch(
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
) as mock_sweep:
await _cleanup_sdk_tool_results(evil_cwd)
# Directory should NOT have been removed (rejected early)
assert os.path.exists(evil_cwd)
mock_sweep.assert_not_called()

View File

@@ -146,7 +146,7 @@ def stash_pending_tool_output(tool_name: str, output: Any) -> None:
event.set()
async def wait_for_stash(timeout: float = 0.5) -> bool:
async def wait_for_stash(timeout: float = 2.0) -> bool:
"""Wait for a PostToolUse hook to stash tool output.
The SDK fires PostToolUse hooks asynchronously via ``start_soon()`` —
@@ -155,12 +155,12 @@ async def wait_for_stash(timeout: float = 0.5) -> bool:
by waiting on the ``_stash_event``, which is signaled by
:func:`stash_pending_tool_output`.
After the event fires, callers should ``await asyncio.sleep(0)`` to
give any remaining concurrent hooks a chance to complete.
Returns ``True`` if a stash signal was received, ``False`` on timeout.
The timeout is a safety net — normally the stash happens within
microseconds of yielding to the event loop.
The 2.0 s default was chosen based on production metrics: the original
0.5 s caused frequent timeouts under load (parallel tool calls, large
outputs). 2.0 s gives a comfortable margin while still failing fast
when the hook genuinely will not fire.
"""
event = _stash_event.get(None)
if event is None:
@@ -234,9 +234,7 @@ def create_tool_handler(base_tool: BaseTool):
try:
return await _execute_tool_sync(base_tool, user_id, session, args)
except Exception as e:
logger.error(
"Error executing tool %s: %s", base_tool.name, e, exc_info=True
)
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
return tool_handler
@@ -287,7 +285,7 @@ async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
resolved = os.path.realpath(os.path.expanduser(file_path))
try:
with open(resolved) as f:
with open(resolved, encoding="utf-8", errors="replace") as f:
selected = list(itertools.islice(f, offset, offset + limit))
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
# don't delete here — the SDK may read in multiple chunks.

View File

@@ -10,9 +10,6 @@ Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
filesystem for self-hosted) — no DB column needed.
"""
from __future__ import annotations
import asyncio
import logging
import os
import re
@@ -20,12 +17,8 @@ import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from uuid import uuid4
from backend.util import json
from backend.util.clients import get_openai_client
from backend.util.prompt import CompressResult, compress_context
from backend.util.workspace_storage import GCSWorkspaceStorage, get_workspace_storage
logger = logging.getLogger(__name__)
@@ -106,14 +99,7 @@ def strip_progress_entries(content: str) -> str:
continue
parent = entry.get("parentUuid", "")
original_parent = parent
# seen_parents is local per-entry (not shared across iterations) so
# it can only detect cycles within a single ancestry walk, not across
# entries. This is intentional: each entry's parent chain is
# independent, and reusing a global set would incorrectly short-circuit
# valid re-use of the same UUID as a parent in different subtrees.
seen_parents: set[str] = set()
while parent in stripped_uuids and parent not in seen_parents:
seen_parents.add(parent)
while parent in stripped_uuids:
parent = uuid_to_parent.get(parent, "")
if parent != original_parent:
entry["parentUuid"] = parent
@@ -165,44 +151,110 @@ def _projects_base() -> str:
return os.path.realpath(os.path.join(config_dir, "projects"))
def _cli_project_dir(sdk_cwd: str) -> str | None:
"""Return the CLI's project directory for a given working directory.
_STALE_PROJECT_DIR_SECONDS = 12 * 3600 # 12 hours — matches max session lifetime
_MAX_PROJECT_DIRS_TO_SWEEP = 50 # limit per sweep to avoid long pauses
Returns ``None`` if the path would escape the projects base.
def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
"""Remove CLI project directories older than ``_STALE_PROJECT_DIR_SECONDS``.
Each CoPilot SDK turn creates a unique ``~/.claude/projects/<encoded-cwd>/``
directory. These are intentionally kept across turns so the model can read
tool-result files via ``--resume``. However, after a session ends they
become stale. This function sweeps old ones to prevent unbounded disk
growth.
When *encoded_cwd* is provided the sweep is scoped to that single
directory, making the operation safe in multi-tenant environments where
multiple copilot sessions share the same host. Without it the function
falls back to sweeping all directories matching the copilot naming pattern
(``-tmp-copilot-``), which is only safe for single-tenant deployments.
Returns the number of directories removed.
"""
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
projects_base = _projects_base()
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
if not os.path.isdir(projects_base):
return 0
if not project_dir.startswith(projects_base + os.sep):
logger.warning(
"[Transcript] Project dir escaped projects base: %s", project_dir
)
return None
return project_dir
now = time.time()
removed = 0
def _safe_glob_jsonl(project_dir: str) -> list[Path]:
"""Glob ``*.jsonl`` files, filtering out symlinks that escape the directory."""
try:
resolved_base = Path(project_dir).resolve()
except OSError as e:
logger.warning("[Transcript] Failed to resolve project dir: %s", e)
return []
result: list[Path] = []
for candidate in Path(project_dir).glob("*.jsonl"):
try:
resolved = candidate.resolve()
if resolved.is_relative_to(resolved_base):
result.append(resolved)
except (OSError, RuntimeError) as e:
logger.debug(
"[Transcript] Skipping invalid CLI session candidate %s: %s",
candidate,
e,
# Scoped mode: only clean up the one directory for the current session.
if encoded_cwd:
target = Path(projects_base) / encoded_cwd
if not target.is_dir():
return 0
# Guard: only sweep copilot-generated dirs.
if "-tmp-copilot-" not in target.name:
logger.warning(
"[Transcript] Refusing to sweep non-copilot dir: %s", target.name
)
return result
return 0
try:
# st_mtime is used as a proxy for session activity. Claude CLI writes
# its JSONL transcript into this directory during each turn, so mtime
# advances on every turn. A directory whose mtime is older than
# _STALE_PROJECT_DIR_SECONDS has not had an active turn in that window
# and is safe to remove (the session cannot --resume after cleanup).
age = now - target.stat().st_mtime
except OSError:
return 0
if age < _STALE_PROJECT_DIR_SECONDS:
return 0
try:
shutil.rmtree(target, ignore_errors=True)
removed = 1
except OSError:
pass
if removed:
logger.info(
"[Transcript] Swept stale CLI project dir %s (age %ds > %ds)",
target.name,
int(age),
_STALE_PROJECT_DIR_SECONDS,
)
return removed
# Unscoped fallback: sweep all copilot dirs across the projects base.
# Only safe for single-tenant deployments; callers should prefer the
# scoped variant by passing encoded_cwd.
try:
entries = Path(projects_base).iterdir()
except OSError as e:
logger.warning("[Transcript] Failed to list projects dir: %s", e)
return 0
for entry in entries:
if removed >= _MAX_PROJECT_DIRS_TO_SWEEP:
break
# Only sweep copilot-generated dirs (pattern: -tmp-copilot- or
# -private-tmp-copilot-).
if "-tmp-copilot-" not in entry.name:
continue
if not entry.is_dir():
continue
try:
# See the scoped-mode comment above: st_mtime advances on every turn,
# so a stale mtime reliably indicates an inactive session.
age = now - entry.stat().st_mtime
except OSError:
continue
if age < _STALE_PROJECT_DIR_SECONDS:
continue
try:
shutil.rmtree(entry, ignore_errors=True)
removed += 1
except OSError:
pass
if removed:
logger.info(
"[Transcript] Swept %d stale CLI project dirs (older than %ds)",
removed,
_STALE_PROJECT_DIR_SECONDS,
)
return removed
def read_compacted_entries(transcript_path: str) -> list[dict] | None:
@@ -269,63 +321,6 @@ def read_compacted_entries(transcript_path: str) -> list[dict] | None:
return entries
def read_cli_session_file(sdk_cwd: str) -> str | None:
"""Read the CLI's own session file, which reflects any compaction.
The CLI writes its session transcript to
``~/.claude/projects/<encoded_cwd>/<session_id>.jsonl``.
Since each SDK turn uses a unique ``sdk_cwd``, there should be
exactly one ``.jsonl`` file in that directory.
Returns the file content, or ``None`` if not found.
"""
project_dir = _cli_project_dir(sdk_cwd)
if not project_dir or not os.path.isdir(project_dir):
return None
jsonl_files = _safe_glob_jsonl(project_dir)
if not jsonl_files:
logger.debug("[Transcript] No CLI session file found in %s", project_dir)
return None
# Pick the most recently modified file (should be only one per turn).
try:
session_file = max(jsonl_files, key=lambda p: p.stat().st_mtime)
except OSError as e:
logger.warning("[Transcript] Failed to inspect CLI session files: %s", e)
return None
try:
content = session_file.read_text()
logger.info(
"[Transcript] Read CLI session file: %s (%d bytes)",
session_file,
len(content),
)
return content
except OSError as e:
logger.warning("[Transcript] Failed to read CLI session file: %s", e)
return None
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
"""Remove the CLI's project directory for a specific working directory.
The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
safe to remove entirely after the transcript has been uploaded.
"""
project_dir = _cli_project_dir(sdk_cwd)
if not project_dir:
return
if os.path.isdir(project_dir):
shutil.rmtree(project_dir, ignore_errors=True)
logger.debug("[Transcript] Cleaned up CLI project dir: %s", project_dir)
else:
logger.debug("[Transcript] Project dir not found: %s", project_dir)
def write_transcript_to_tempfile(
transcript_content: str,
session_id: str,
@@ -341,7 +336,7 @@ def write_transcript_to_tempfile(
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
real_cwd = os.path.realpath(cwd)
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
logger.warning("[Transcript] cwd outside sandbox: %s", cwd)
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
return None
try:
@@ -351,17 +346,17 @@ def write_transcript_to_tempfile(
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
)
if not jsonl_path.startswith(real_cwd):
logger.warning("[Transcript] Path escaped cwd: %s", jsonl_path)
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
return None
with open(jsonl_path, "w") as f:
f.write(transcript_content)
logger.info("[Transcript] Wrote resume file: %s", jsonl_path)
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
return jsonl_path
except OSError as e:
logger.warning("[Transcript] Failed to write resume file: %s", e)
logger.warning(f"[Transcript] Failed to write resume file: {e}")
return None
@@ -422,6 +417,8 @@ def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, s
def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
"""Build a full storage path from (workspace_id, file_id, filename) parts."""
from backend.util.workspace_storage import GCSWorkspaceStorage
wid, fid, fname = parts
if isinstance(backend, GCSWorkspaceStorage):
blob = f"workspaces/{wid}/{fid}/{fname}"
@@ -460,15 +457,17 @@ async def upload_transcript(
content: Complete JSONL transcript (from TranscriptBuilder).
message_count: ``len(session.messages)`` at upload time.
"""
from backend.util.workspace_storage import get_workspace_storage
# Strip metadata entries (progress, file-history-snapshot, etc.)
# Note: SDK-built transcripts shouldn't have these, but strip for safety
stripped = strip_progress_entries(content)
if not validate_transcript(stripped):
# Log entry types for debugging — helps identify why validation failed
entry_types = [
json.loads(line, fallback={"type": "INVALID_JSON"}).get("type", "?")
for line in stripped.strip().split("\n")
]
entry_types: list[str] = []
for line in stripped.strip().split("\n"):
entry = json.loads(line, fallback={"type": "INVALID_JSON"})
entry_types.append(entry.get("type", "?"))
logger.warning(
"%s Skipping upload — stripped content not valid "
"(types=%s, stripped_len=%d, raw_len=%d)",
@@ -504,14 +503,11 @@ async def upload_transcript(
content=json.dumps(meta).encode("utf-8"),
)
except Exception as e:
logger.warning("%s Failed to write metadata: %s", log_prefix, e)
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
logger.info(
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
log_prefix,
len(encoded),
len(content),
message_count,
f"{log_prefix} Uploaded {len(encoded)}B "
f"(stripped from {len(content)}B, msg_count={message_count})"
)
@@ -525,6 +521,8 @@ async def download_transcript(
Returns a ``TranscriptDownload`` with the JSONL content and the
``message_count`` watermark from the upload, or ``None`` if not found.
"""
from backend.util.workspace_storage import get_workspace_storage
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
@@ -532,10 +530,10 @@ async def download_transcript(
data = await storage.retrieve(path)
content = data.decode("utf-8")
except FileNotFoundError:
logger.debug("%s No transcript in storage", log_prefix)
logger.debug(f"{log_prefix} No transcript in storage")
return None
except Exception as e:
logger.warning("%s Failed to download transcript: %s", log_prefix, e)
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
return None
# Try to load metadata (best-effort — old transcripts won't have it)
@@ -547,14 +545,10 @@ async def download_transcript(
meta = json.loads(meta_data.decode("utf-8"), fallback={})
message_count = meta.get("message_count", 0)
uploaded_at = meta.get("uploaded_at", 0.0)
except FileNotFoundError:
except (FileNotFoundError, Exception):
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
except Exception as e:
logger.debug("%s Failed to load transcript metadata: %s", log_prefix, e)
logger.info(
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
)
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
return TranscriptDownload(
content=content,
message_count=message_count,
@@ -568,6 +562,8 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
Removes both the ``.jsonl`` transcript and the companion ``.meta.json``
so stale ``message_count`` watermarks cannot corrupt gap-fill logic.
"""
from backend.util.workspace_storage import get_workspace_storage
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
@@ -584,280 +580,3 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
logger.info("[Transcript] Deleted metadata for session %s", session_id)
except Exception as e:
logger.warning("[Transcript] Failed to delete metadata: %s", e)
# ---------------------------------------------------------------------------
# Transcript compaction — LLM summarization for prompt-too-long recovery
# ---------------------------------------------------------------------------
# JSONL protocol values used in transcript serialization.
STOP_REASON_END_TURN = "end_turn"
COMPACT_MSG_ID_PREFIX = "msg_compact_"
ENTRY_TYPE_MESSAGE = "message"
def _flatten_assistant_content(blocks: list) -> str:
"""Flatten assistant content blocks into a single plain-text string.
Structured ``tool_use`` blocks are converted to ``[tool_use: name]``
placeholders. This is intentional: ``compress_context`` requires plain
text for token counting and LLM summarization. The structural loss is
acceptable because compaction only runs when the original transcript was
already too large for the model — a summarized plain-text version is
better than no context at all.
"""
parts: list[str] = []
for block in blocks:
if isinstance(block, dict):
btype = block.get("type", "")
if btype == "text":
parts.append(block.get("text", ""))
elif btype == "tool_use":
parts.append(f"[tool_use: {block.get('name', '?')}]")
else:
# Preserve non-text blocks (e.g. image) as placeholders.
# Use __prefix__ to distinguish from literal user text.
parts.append(f"[__{btype}__]")
elif isinstance(block, str):
parts.append(block)
return "\n".join(parts) if parts else ""
def _flatten_tool_result_content(blocks: list) -> str:
"""Flatten tool_result and other content blocks into plain text.
Handles nested tool_result structures, text blocks, and raw strings.
Uses ``json.dumps`` as fallback for dict blocks without a ``text`` key
or where ``text`` is ``None``.
Like ``_flatten_assistant_content``, structured blocks (images, nested
tool results) are reduced to text representations for compression.
"""
str_parts: list[str] = []
for block in blocks:
if isinstance(block, dict) and block.get("type") == "tool_result":
inner = block.get("content") or ""
if isinstance(inner, list):
for sub in inner:
if isinstance(sub, dict):
sub_type = sub.get("type")
if sub_type in ("image", "document"):
# Avoid serializing base64 binary data into
# the compaction input — use a placeholder.
str_parts.append(f"[__{sub_type}__]")
elif sub_type == "text" or sub.get("text") is not None:
str_parts.append(str(sub.get("text", "")))
else:
str_parts.append(json.dumps(sub))
else:
str_parts.append(str(sub))
else:
str_parts.append(str(inner))
elif isinstance(block, dict) and block.get("type") == "text":
str_parts.append(str(block.get("text", "")))
elif isinstance(block, dict):
# Preserve non-text/non-tool_result blocks (e.g. image) as placeholders.
# Use __prefix__ to distinguish from literal user text.
btype = block.get("type", "unknown")
str_parts.append(f"[__{btype}__]")
elif isinstance(block, str):
str_parts.append(block)
return "\n".join(str_parts) if str_parts else ""
def _transcript_to_messages(content: str) -> list[dict]:
"""Convert JSONL transcript entries to plain message dicts for compression.
Parses each line of the JSONL *content*, skips strippable metadata entries
(progress, file-history-snapshot, etc.), and extracts the ``role`` and
flattened ``content`` from the ``message`` field of each remaining entry.
Structured content blocks (``tool_use``, ``tool_result``, images) are
flattened to plain text via ``_flatten_assistant_content`` and
``_flatten_tool_result_content`` so that ``compress_context`` can
perform token counting and LLM summarization on uniform strings.
Returns:
A list of ``{"role": str, "content": str}`` dicts suitable for
``compress_context``.
"""
messages: list[dict] = []
for line in content.strip().split("\n"):
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if not isinstance(entry, dict):
continue
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
"isCompactSummary"
):
continue
msg = entry.get("message", {})
role = msg.get("role", "")
if not role:
continue
msg_dict: dict = {"role": role}
raw_content = msg.get("content")
if role == "assistant" and isinstance(raw_content, list):
msg_dict["content"] = _flatten_assistant_content(raw_content)
elif isinstance(raw_content, list):
msg_dict["content"] = _flatten_tool_result_content(raw_content)
else:
msg_dict["content"] = raw_content or ""
messages.append(msg_dict)
return messages
def _messages_to_transcript(messages: list[dict]) -> str:
"""Convert compressed message dicts back to JSONL transcript format.
Rebuilds a minimal JSONL transcript from the ``{"role", "content"}``
dicts returned by ``compress_context``. Each message becomes one JSONL
line with a fresh ``uuid`` / ``parentUuid`` chain so the CLI's
``--resume`` flag can reconstruct a valid conversation tree.
Assistant messages are wrapped in the full ``message`` envelope
(``id``, ``model``, ``stop_reason``, structured ``content`` blocks)
that the CLI expects. User messages use the simpler ``{role, content}``
form.
Returns:
A newline-terminated JSONL string, or an empty string if *messages*
is empty.
"""
lines: list[str] = []
last_uuid: str = "" # root entry uses empty string, not null
for msg in messages:
role = msg.get("role", "user")
entry_type = "assistant" if role == "assistant" else "user"
uid = str(uuid4())
content = msg.get("content", "")
if role == "assistant":
message: dict = {
"role": "assistant",
"model": "",
"id": f"{COMPACT_MSG_ID_PREFIX}{uuid4().hex[:24]}",
"type": ENTRY_TYPE_MESSAGE,
"content": [{"type": "text", "text": content}] if content else [],
"stop_reason": STOP_REASON_END_TURN,
"stop_sequence": None,
}
else:
message = {"role": role, "content": content}
entry = {
"type": entry_type,
"uuid": uid,
"parentUuid": last_uuid,
"message": message,
}
lines.append(json.dumps(entry, separators=(",", ":")))
last_uuid = uid
return "\n".join(lines) + "\n" if lines else ""
_COMPACTION_TIMEOUT_SECONDS = 60
_TRUNCATION_TIMEOUT_SECONDS = 30
async def _run_compression(
messages: list[dict],
model: str,
log_prefix: str,
) -> CompressResult:
"""Run LLM-based compression with truncation fallback.
Uses the shared OpenAI client from ``get_openai_client()``.
If no client is configured or the LLM call fails, falls back to
truncation-based compression which drops older messages without
summarization.
A 60-second timeout prevents a hung LLM call from blocking the
retry path indefinitely. The truncation fallback also has a
30-second timeout to guard against slow tokenization on very large
transcripts.
"""
client = get_openai_client()
if client is None:
logger.warning("%s No OpenAI client configured, using truncation", log_prefix)
return await asyncio.wait_for(
compress_context(messages=messages, model=model, client=None),
timeout=_TRUNCATION_TIMEOUT_SECONDS,
)
try:
return await asyncio.wait_for(
compress_context(messages=messages, model=model, client=client),
timeout=_COMPACTION_TIMEOUT_SECONDS,
)
except Exception as e:
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
return await asyncio.wait_for(
compress_context(messages=messages, model=model, client=None),
timeout=_TRUNCATION_TIMEOUT_SECONDS,
)
async def compact_transcript(
content: str,
*,
model: str,
log_prefix: str = "[Transcript]",
) -> str | None:
"""Compact an oversized JSONL transcript using LLM summarization.
Converts transcript entries to plain messages, runs ``compress_context``
(the same compressor used for pre-query history), and rebuilds JSONL.
Structured content (``tool_use`` blocks, ``tool_result`` nesting, images)
is flattened to plain text for compression. This matches the fidelity of
the Plan C (DB compression) fallback path, where
``_format_conversation_context`` similarly renders tool calls as
``You called tool: name(args)`` and results as ``Tool result: ...``.
Neither path preserves structured API content blocks — the compacted
context serves as text history for the LLM, which creates proper
structured tool calls going forward.
Images are per-turn attachments loaded from workspace storage by file ID
(via ``_prepare_file_attachments``), not part of the conversation history.
They are re-attached each turn and are unaffected by compaction.
Returns the compacted JSONL string, or ``None`` on failure.
See also:
``_compress_messages`` in ``service.py`` — compresses ``ChatMessage``
lists for pre-query DB history. Both share ``compress_context()``
but operate on different input formats (JSONL transcript entries
here vs. ChatMessage dicts there).
"""
messages = _transcript_to_messages(content)
if len(messages) < 2:
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
return None
try:
result = await _run_compression(messages, model, log_prefix)
if not result.was_compacted:
# Compressor says it's within budget, but the SDK rejected it.
# Return None so the caller falls through to DB fallback.
logger.warning(
"%s Compressor reports within budget but SDK rejected — "
"signalling failure",
log_prefix,
)
return None
logger.info(
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
log_prefix,
result.original_token_count,
result.token_count,
result.messages_summarized,
result.messages_dropped,
)
compacted = _messages_to_transcript(result.messages)
if not validate_transcript(compacted):
logger.warning("%s Compacted transcript failed validation", log_prefix)
return None
return compacted
except Exception as e:
logger.error(
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
)
return None

View File

@@ -68,7 +68,7 @@ class TranscriptBuilder:
type=entry_type,
uuid=data.get("uuid") or str(uuid4()),
parentUuid=data.get("parentUuid"),
isCompactSummary=data.get("isCompactSummary"),
isCompactSummary=data.get("isCompactSummary") or None,
message=data.get("message", {}),
)

View File

@@ -1,7 +1,7 @@
"""Unit tests for JSONL transcript management utilities."""
import os
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, patch
import pytest
@@ -9,9 +9,7 @@ from backend.util import json
from .transcript import (
STRIPPABLE_TYPES,
_cli_project_dir,
delete_transcript,
read_cli_session_file,
read_compacted_entries,
strip_progress_entries,
validate_transcript,
@@ -292,85 +290,6 @@ class TestStripProgressEntries:
assert asst_entry["parentUuid"] == "u1" # reparented
# --- read_cli_session_file ---
class TestReadCliSessionFile:
def test_no_matching_files_returns_none(self, tmp_path, monkeypatch):
"""read_cli_session_file returns None when no .jsonl files exist."""
# Create a project dir with no jsonl files
project_dir = tmp_path / "projects" / "encoded-cwd"
project_dir.mkdir(parents=True)
monkeypatch.setattr(
"backend.copilot.sdk.transcript._cli_project_dir",
lambda sdk_cwd: str(project_dir),
)
assert read_cli_session_file("/fake/cwd") is None
def test_one_jsonl_file_returns_content(self, tmp_path, monkeypatch):
"""read_cli_session_file returns the content of a single .jsonl file."""
project_dir = tmp_path / "projects" / "encoded-cwd"
project_dir.mkdir(parents=True)
jsonl_file = project_dir / "session.jsonl"
jsonl_file.write_text("line1\nline2\n")
monkeypatch.setattr(
"backend.copilot.sdk.transcript._cli_project_dir",
lambda sdk_cwd: str(project_dir),
)
result = read_cli_session_file("/fake/cwd")
assert result == "line1\nline2\n"
def test_symlink_escaping_project_dir_is_skipped(self, tmp_path, monkeypatch):
"""read_cli_session_file skips symlinks that escape the project dir."""
project_dir = tmp_path / "projects" / "encoded-cwd"
project_dir.mkdir(parents=True)
# Create a file outside the project dir
outside = tmp_path / "outside"
outside.mkdir()
outside_file = outside / "evil.jsonl"
outside_file.write_text("should not be read\n")
# Symlink from inside project_dir to outside file
symlink = project_dir / "evil.jsonl"
symlink.symlink_to(outside_file)
monkeypatch.setattr(
"backend.copilot.sdk.transcript._cli_project_dir",
lambda sdk_cwd: str(project_dir),
)
# The symlink target resolves outside project_dir, so it should be skipped
result = read_cli_session_file("/fake/cwd")
assert result is None
# --- _cli_project_dir ---
class TestCliProjectDir:
def test_returns_none_for_path_traversal(self, tmp_path, monkeypatch):
"""_cli_project_dir returns None when the project dir symlink escapes projects base."""
config_dir = tmp_path / "config"
config_dir.mkdir()
projects_dir = config_dir / "projects"
projects_dir.mkdir()
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
# Create a symlink inside projects/ that points outside of it.
# _cli_project_dir encodes the cwd as all-alnum-hyphens, so use a
# cwd whose encoded form matches the symlink name we create.
evil_target = tmp_path / "escaped"
evil_target.mkdir()
# The encoded form of "/evil/cwd" is "-evil-cwd"
symlink_path = projects_dir / "-evil-cwd"
symlink_path.symlink_to(evil_target)
result = _cli_project_dir("/evil/cwd")
assert result is None
# --- delete_transcript ---
@@ -382,7 +301,7 @@ class TestDeleteTranscript:
mock_storage.delete = AsyncMock()
with patch(
"backend.copilot.sdk.transcript.get_workspace_storage",
"backend.util.workspace_storage.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
@@ -402,7 +321,7 @@ class TestDeleteTranscript:
)
with patch(
"backend.copilot.sdk.transcript.get_workspace_storage",
"backend.util.workspace_storage.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
@@ -420,7 +339,7 @@ class TestDeleteTranscript:
)
with patch(
"backend.copilot.sdk.transcript.get_workspace_storage",
"backend.util.workspace_storage.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
@@ -900,131 +819,206 @@ class TestCompactionFlowIntegration:
# ---------------------------------------------------------------------------
# _run_compression (direct tests for the 3 code paths)
# cleanup_stale_project_dirs
# ---------------------------------------------------------------------------
class TestRunCompression:
"""Direct tests for ``_run_compression`` covering all 3 code paths.
class TestCleanupStaleProjectDirs:
"""Tests for cleanup_stale_project_dirs (disk leak prevention)."""
Paths:
(a) No OpenAI client configured → truncation fallback immediately.
(b) LLM success → returns LLM-compressed result.
(c) LLM call raises → truncation fallback.
"""
def _make_compress_result(self, was_compacted: bool, msgs=None):
"""Build a minimal CompressResult-like object."""
from types import SimpleNamespace
return SimpleNamespace(
was_compacted=was_compacted,
messages=msgs or [{"role": "user", "content": "summary"}],
original_token_count=500,
token_count=100 if was_compacted else 500,
messages_summarized=2 if was_compacted else 0,
messages_dropped=0,
def test_removes_old_copilot_dirs(self, tmp_path, monkeypatch):
"""Directories matching copilot pattern older than threshold are removed."""
from backend.copilot.sdk.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
@pytest.mark.asyncio
async def test_no_client_uses_truncation(self):
"""Path (a): ``get_openai_client()`` returns None → truncation only."""
from .transcript import _run_compression
truncation_result = self._make_compress_result(
True, [{"role": "user", "content": "truncated"}]
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: str(projects_dir),
)
with (
patch(
"backend.copilot.sdk.transcript.get_openai_client",
return_value=None,
),
patch(
"backend.copilot.sdk.transcript.compress_context",
new_callable=AsyncMock,
return_value=truncation_result,
) as mock_compress,
):
result = await _run_compression(
[{"role": "user", "content": "hello"}],
model="test-model",
log_prefix="[test]",
)
# Create a stale dir
stale = projects_dir / "-tmp-copilot-old-session"
stale.mkdir()
# Set mtime to past the threshold
import time
# compress_context called with client=None (truncation mode)
call_kwargs = mock_compress.call_args
assert (
call_kwargs.kwargs.get("client") is None
or (call_kwargs.args and call_kwargs.args[2] is None)
or mock_compress.call_args[1].get("client") is None
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
os.utime(stale, (old_time, old_time))
# Create a fresh dir
fresh = projects_dir / "-tmp-copilot-new-session"
fresh.mkdir()
removed = cleanup_stale_project_dirs()
assert removed == 1
assert not stale.exists()
assert fresh.exists()
def test_ignores_non_copilot_dirs(self, tmp_path, monkeypatch):
"""Directories not matching copilot pattern are left alone."""
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: str(projects_dir),
)
assert result is truncation_result
@pytest.mark.asyncio
async def test_llm_success_returns_llm_result(self):
"""Path (b): ``get_openai_client()`` returns a client → LLM compresses."""
from .transcript import _run_compression
# Non-copilot dir that's old
import time
llm_result = self._make_compress_result(
True, [{"role": "user", "content": "LLM summary"}]
other = projects_dir / "some-other-project"
other.mkdir()
old_time = time.time() - 999999
os.utime(other, (old_time, old_time))
removed = cleanup_stale_project_dirs()
assert removed == 0
assert other.exists()
def test_ttl_boundary_not_removed(self, tmp_path, monkeypatch):
"""A directory exactly at the TTL boundary should NOT be removed."""
from backend.copilot.sdk.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
mock_client = MagicMock()
with (
patch(
"backend.copilot.sdk.transcript.get_openai_client",
return_value=mock_client,
),
patch(
"backend.copilot.sdk.transcript.compress_context",
new_callable=AsyncMock,
return_value=llm_result,
) as mock_compress,
):
result = await _run_compression(
[{"role": "user", "content": "long conversation"}],
model="test-model",
log_prefix="[test]",
)
# compress_context called with the real client
assert mock_compress.called
assert result is llm_result
@pytest.mark.asyncio
async def test_llm_failure_falls_back_to_truncation(self):
"""Path (c): LLM call raises → truncation fallback used instead."""
from .transcript import _run_compression
truncation_result = self._make_compress_result(
True, [{"role": "user", "content": "truncated fallback"}]
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: str(projects_dir),
)
mock_client = MagicMock()
call_count = [0]
async def _compress_side_effect(**kwargs):
call_count[0] += 1
if kwargs.get("client") is not None:
raise RuntimeError("LLM timeout")
return truncation_result
import time
with (
patch(
"backend.copilot.sdk.transcript.get_openai_client",
return_value=mock_client,
),
patch(
"backend.copilot.sdk.transcript.compress_context",
side_effect=_compress_side_effect,
),
):
result = await _run_compression(
[{"role": "user", "content": "long conversation"}],
model="test-model",
log_prefix="[test]",
)
# Dir that's exactly at the TTL (age == threshold, not >) — should survive
boundary = projects_dir / "-tmp-copilot-boundary"
boundary.mkdir()
boundary_time = time.time() - _STALE_PROJECT_DIR_SECONDS + 1
os.utime(boundary, (boundary_time, boundary_time))
# compress_context called twice: once for LLM (raises), once for truncation
assert call_count[0] == 2
assert result is truncation_result
removed = cleanup_stale_project_dirs()
assert removed == 0
assert boundary.exists()
def test_skips_non_directory_entries(self, tmp_path, monkeypatch):
"""Regular files matching the copilot pattern are not removed."""
from backend.copilot.sdk.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: str(projects_dir),
)
import time
# Create a regular FILE (not a dir) with the copilot pattern name
stale_file = projects_dir / "-tmp-copilot-stale-file"
stale_file.write_text("not a dir")
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
os.utime(stale_file, (old_time, old_time))
removed = cleanup_stale_project_dirs()
assert removed == 0
assert stale_file.exists()
def test_missing_base_dir_returns_zero(self, tmp_path, monkeypatch):
"""If the projects base directory doesn't exist, return 0 gracefully."""
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
nonexistent = str(tmp_path / "does-not-exist" / "projects")
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: nonexistent,
)
removed = cleanup_stale_project_dirs()
assert removed == 0
def test_scoped_removes_only_target_dir(self, tmp_path, monkeypatch):
"""When encoded_cwd is supplied only that directory is swept."""
import time
from backend.copilot.sdk.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: str(projects_dir),
)
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
# Two stale copilot dirs
target = projects_dir / "-tmp-copilot-session-abc"
target.mkdir()
os.utime(target, (old_time, old_time))
other = projects_dir / "-tmp-copilot-session-xyz"
other.mkdir()
os.utime(other, (old_time, old_time))
# Only the target dir should be removed
removed = cleanup_stale_project_dirs(encoded_cwd="-tmp-copilot-session-abc")
assert removed == 1
assert not target.exists()
assert other.exists() # untouched — not the current session
def test_scoped_fresh_dir_not_removed(self, tmp_path, monkeypatch):
"""Scoped sweep leaves a fresh directory alone."""
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: str(projects_dir),
)
fresh = projects_dir / "-tmp-copilot-session-new"
fresh.mkdir()
# mtime is now — well within TTL
removed = cleanup_stale_project_dirs(encoded_cwd="-tmp-copilot-session-new")
assert removed == 0
assert fresh.exists()
def test_scoped_non_copilot_dir_not_removed(self, tmp_path, monkeypatch):
"""Scoped sweep refuses to remove a non-copilot directory."""
import time
from backend.copilot.sdk.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: str(projects_dir),
)
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
non_copilot = projects_dir / "some-other-project"
non_copilot.mkdir()
os.utime(non_copilot, (old_time, old_time))
removed = cleanup_stale_project_dirs(encoded_cwd="some-other-project")
assert removed == 0
assert non_copilot.exists()

View File

@@ -41,7 +41,8 @@ import contextlib
import logging
from typing import Any, Awaitable, Callable, Literal
from e2b import AsyncSandbox, SandboxLifecycle
from e2b import AsyncSandbox
from e2b.sandbox.sandbox_api import SandboxLifecycle
from backend.data.redis_client import get_redis_async

View File

@@ -2,6 +2,7 @@
import base64
import logging
import mimetypes
import os
from typing import Any, Optional
@@ -10,7 +11,9 @@ from pydantic import BaseModel
from backend.copilot.context import (
E2B_WORKDIR,
get_current_sandbox,
get_sdk_cwd,
get_workspace_manager,
is_allowed_local_path,
resolve_sandbox_path,
)
from backend.copilot.model import ChatSession
@@ -24,6 +27,10 @@ from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
# Sentinel file_id used when a tool-result file is read directly from the local
# host filesystem (rather than from workspace storage).
_LOCAL_TOOL_RESULT_FILE_ID = "local"
async def _resolve_write_content(
content_text: str | None,
@@ -275,6 +282,93 @@ class WorkspaceFileContentResponse(ToolResponseBase):
content_base64: str
_MAX_LOCAL_TOOL_RESULT_BYTES = 10 * 1024 * 1024 # 10 MB
def _read_local_tool_result(
path: str,
char_offset: int,
char_length: Optional[int],
session_id: str,
sdk_cwd: str | None = None,
) -> ToolResponseBase:
"""Read an SDK tool-result file from local disk.
This is a fallback for when the model mistakenly calls
``read_workspace_file`` with an SDK tool-result path that only exists on
the host filesystem, not in cloud workspace storage.
Defence-in-depth: validates *path* via :func:`is_allowed_local_path`
regardless of what the caller has already checked.
"""
# TOCTOU: path validated then opened separately. Acceptable because
# the tool-results directory is server-controlled, not user-writable.
expanded = os.path.realpath(os.path.expanduser(path))
# Defence-in-depth: re-check with resolved path (caller checked raw path).
if not is_allowed_local_path(expanded, sdk_cwd or get_sdk_cwd()):
return ErrorResponse(
message=f"Path not allowed: {os.path.basename(path)}", session_id=session_id
)
try:
# The 10 MB cap (_MAX_LOCAL_TOOL_RESULT_BYTES) bounds memory usage.
# Pre-read size check prevents loading files far above the cap;
# the remaining TOCTOU gap is acceptable for server-controlled paths.
file_size = os.path.getsize(expanded)
if file_size > _MAX_LOCAL_TOOL_RESULT_BYTES:
return ErrorResponse(
message=(f"File too large: {os.path.basename(path)}"),
session_id=session_id,
)
# Detect binary files: try strict UTF-8 first, fall back to
# base64-encoding the raw bytes for binary content.
with open(expanded, "rb") as fh:
raw = fh.read()
try:
text_content = raw.decode("utf-8")
except UnicodeDecodeError:
# Binary file — return raw base64, ignore char_offset/char_length
return WorkspaceFileContentResponse(
file_id=_LOCAL_TOOL_RESULT_FILE_ID,
name=os.path.basename(path),
path=path,
mime_type=mimetypes.guess_type(path)[0] or "application/octet-stream",
content_base64=base64.b64encode(raw).decode("ascii"),
message=(
f"Read {file_size:,} bytes (binary) from local tool-result "
f"{os.path.basename(path)}"
),
session_id=session_id,
)
end = (
char_offset + char_length if char_length is not None else len(text_content)
)
slice_text = text_content[char_offset:end]
except FileNotFoundError:
return ErrorResponse(
message=f"File not found: {os.path.basename(path)}", session_id=session_id
)
except Exception as exc:
return ErrorResponse(
message=f"Error reading file: {type(exc).__name__}", session_id=session_id
)
return WorkspaceFileContentResponse(
file_id=_LOCAL_TOOL_RESULT_FILE_ID,
name=os.path.basename(path),
path=path,
mime_type=mimetypes.guess_type(path)[0] or "text/plain",
content_base64=base64.b64encode(slice_text.encode("utf-8")).decode("ascii"),
message=(
f"Read chars {char_offset}\u2013{char_offset + len(slice_text)} "
f"of {len(text_content):,} chars from local tool-result "
f"{os.path.basename(path)}"
),
session_id=session_id,
)
class WorkspaceFileMetadataResponse(ToolResponseBase):
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
@@ -533,6 +627,14 @@ class ReadWorkspaceFileTool(BaseTool):
manager = await get_workspace_manager(user_id, session_id)
resolved = await _resolve_file(manager, file_id, path, session_id)
if isinstance(resolved, ErrorResponse):
# Fallback: if the path is an SDK tool-result on local disk,
# read it directly instead of failing. The model sometimes
# calls read_workspace_file for these paths by mistake.
sdk_cwd = get_sdk_cwd()
if path and is_allowed_local_path(path, sdk_cwd):
return _read_local_tool_result(
path, char_offset, char_length, session_id, sdk_cwd=sdk_cwd
)
return resolved
target_file_id, file_info = resolved

View File

@@ -2,18 +2,25 @@
import base64
import os
import shutil
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.context import SDK_PROJECTS_DIR, _current_project_dir
from backend.copilot.tools._test_data import make_session, setup_test_data
from backend.copilot.tools.models import ErrorResponse
from backend.copilot.tools.workspace_files import (
_MAX_LOCAL_TOOL_RESULT_BYTES,
DeleteWorkspaceFileTool,
ListWorkspaceFilesTool,
ReadWorkspaceFileTool,
WorkspaceDeleteResponse,
WorkspaceFileContentResponse,
WorkspaceFileListResponse,
WorkspaceWriteResponse,
WriteWorkspaceFileTool,
_read_local_tool_result,
_resolve_write_content,
_validate_ephemeral_path,
)
@@ -325,3 +332,294 @@ async def test_write_workspace_file_source_path(setup_test_data):
await delete_tool._execute(
user_id=user.id, session=session, file_id=write_resp.file_id
)
# ---------------------------------------------------------------------------
# _read_local_tool_result — local disk fallback for SDK tool-result files
# ---------------------------------------------------------------------------
_CONV_UUID = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
class TestReadLocalToolResult:
"""Tests for _read_local_tool_result (local disk fallback)."""
def _make_tool_result(self, encoded: str, filename: str, content: bytes) -> str:
"""Create a tool-results file and return its path."""
tool_dir = os.path.join(SDK_PROJECTS_DIR, encoded, _CONV_UUID, "tool-results")
os.makedirs(tool_dir, exist_ok=True)
filepath = os.path.join(tool_dir, filename)
with open(filepath, "wb") as f:
f.write(content)
return filepath
def _cleanup(self, encoded: str) -> None:
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
def test_read_text_file(self):
"""Read a UTF-8 text tool-result file."""
encoded = "-tmp-copilot-local-read-text"
path = self._make_tool_result(encoded, "output.txt", b"hello world")
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 0, None, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
decoded = base64.b64decode(result.content_base64).decode("utf-8")
assert decoded == "hello world"
assert "text/plain" in result.mime_type
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_read_text_with_offset(self):
"""Read a slice of a text file using char_offset and char_length."""
encoded = "-tmp-copilot-local-read-offset"
path = self._make_tool_result(encoded, "data.txt", b"ABCDEFGHIJ")
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 3, 4, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
decoded = base64.b64decode(result.content_base64).decode("utf-8")
assert decoded == "DEFG"
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_read_binary_file(self):
"""Binary files are returned as raw base64."""
encoded = "-tmp-copilot-local-read-binary"
binary_data = bytes(range(256))
path = self._make_tool_result(encoded, "image.png", binary_data)
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 0, None, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
decoded = base64.b64decode(result.content_base64)
assert decoded == binary_data
assert "binary" in result.message
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_disallowed_path_rejected(self):
"""Paths not under allowed directories are rejected."""
result = _read_local_tool_result("/etc/passwd", 0, None, "s1")
assert isinstance(result, ErrorResponse)
assert "not allowed" in result.message.lower()
def test_file_not_found(self):
"""Missing files return an error."""
encoded = "-tmp-copilot-local-read-missing"
tool_dir = os.path.join(SDK_PROJECTS_DIR, encoded, _CONV_UUID, "tool-results")
os.makedirs(tool_dir, exist_ok=True)
path = os.path.join(tool_dir, "nope.txt")
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 0, None, "s1")
assert isinstance(result, ErrorResponse)
assert "not found" in result.message.lower()
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_file_too_large(self, monkeypatch):
"""Files exceeding the size limit are rejected."""
encoded = "-tmp-copilot-local-read-large"
# Create a small file but fake os.path.getsize to return a huge value
path = self._make_tool_result(encoded, "big.txt", b"small")
token = _current_project_dir.set(encoded)
monkeypatch.setattr(
"os.path.getsize", lambda _: _MAX_LOCAL_TOOL_RESULT_BYTES + 1
)
try:
result = _read_local_tool_result(path, 0, None, "s1")
assert isinstance(result, ErrorResponse)
assert "too large" in result.message.lower()
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_offset_beyond_file_length(self):
"""Offset past end-of-file returns empty content."""
encoded = "-tmp-copilot-local-read-past-eof"
path = self._make_tool_result(encoded, "short.txt", b"abc")
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 999, 10, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
decoded = base64.b64decode(result.content_base64).decode("utf-8")
assert decoded == ""
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_zero_length_read(self):
"""Requesting zero characters returns empty content."""
encoded = "-tmp-copilot-local-read-zero-len"
path = self._make_tool_result(encoded, "data.txt", b"ABCDEF")
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 2, 0, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
decoded = base64.b64decode(result.content_base64).decode("utf-8")
assert decoded == ""
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_mime_type_from_json_extension(self):
"""JSON files get application/json MIME type, not hardcoded text/plain."""
encoded = "-tmp-copilot-local-read-json"
path = self._make_tool_result(encoded, "result.json", b'{"key": "value"}')
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 0, None, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
assert result.mime_type == "application/json"
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_mime_type_from_png_extension(self):
"""Binary .png files get image/png MIME type via mimetypes."""
encoded = "-tmp-copilot-local-read-png-mime"
binary_data = bytes(range(256))
path = self._make_tool_result(encoded, "chart.png", binary_data)
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 0, None, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
assert result.mime_type == "image/png"
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_explicit_sdk_cwd_parameter(self):
"""The sdk_cwd parameter overrides get_sdk_cwd() for path validation."""
encoded = "-tmp-copilot-local-read-sdkcwd"
path = self._make_tool_result(encoded, "out.txt", b"content")
token = _current_project_dir.set(encoded)
try:
# Pass sdk_cwd explicitly — should still succeed because the path
# is under SDK_PROJECTS_DIR which is always allowed.
result = _read_local_tool_result(
path, 0, None, "s1", sdk_cwd="/tmp/copilot-test"
)
assert isinstance(result, WorkspaceFileContentResponse)
decoded = base64.b64decode(result.content_base64).decode("utf-8")
assert decoded == "content"
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_offset_with_no_length_reads_to_end(self):
"""When char_length is None, read from offset to end of file."""
encoded = "-tmp-copilot-local-read-offset-noLen"
path = self._make_tool_result(encoded, "data.txt", b"0123456789")
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 5, None, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
decoded = base64.b64decode(result.content_base64).decode("utf-8")
assert decoded == "56789"
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
# ---------------------------------------------------------------------------
# ReadWorkspaceFileTool fallback to _read_local_tool_result
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_read_workspace_file_falls_back_to_local_tool_result(setup_test_data):
"""When _resolve_file returns ErrorResponse for an allowed local path,
ReadWorkspaceFileTool should fall back to _read_local_tool_result."""
user = setup_test_data["user"]
session = make_session(user.id)
# Create a real tool-result file on disk so the fallback can read it.
encoded = "-tmp-copilot-fallback-test"
conv_uuid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
tool_dir = os.path.join(SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-results")
os.makedirs(tool_dir, exist_ok=True)
filepath = os.path.join(tool_dir, "result.txt")
with open(filepath, "w") as f:
f.write("fallback content")
token = _current_project_dir.set(encoded)
try:
# Mock _resolve_file to return an ErrorResponse (simulating "file not
# found in workspace") so the fallback branch is exercised.
mock_resolve = AsyncMock(
return_value=ErrorResponse(
message="File not found at path: result.txt",
session_id=session.session_id,
)
)
with patch("backend.copilot.tools.workspace_files._resolve_file", mock_resolve):
read_tool = ReadWorkspaceFileTool()
result = await read_tool._execute(
user_id=user.id,
session=session,
path=filepath,
)
# Should have fallen back to _read_local_tool_result and succeeded.
assert isinstance(result, WorkspaceFileContentResponse), (
f"Expected fallback to local read, got {type(result).__name__}: "
f"{getattr(result, 'message', '')}"
)
decoded = base64.b64decode(result.content_base64).decode("utf-8")
assert decoded == "fallback content"
mock_resolve.assert_awaited_once()
finally:
_current_project_dir.reset(token)
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
@pytest.mark.asyncio(loop_scope="session")
async def test_read_workspace_file_no_fallback_when_resolve_succeeds(setup_test_data):
"""When _resolve_file succeeds, the local-disk fallback must NOT be invoked."""
user = setup_test_data["user"]
session = make_session(user.id)
fake_file_id = "fake-file-id-001"
fake_content = b"workspace content"
# Build a minimal file_info stub that the tool's happy-path needs.
class _FakeFileInfo:
id = fake_file_id
name = "result.json"
path = "/result.json"
mime_type = "text/plain"
size_bytes = len(fake_content)
mock_resolve = AsyncMock(return_value=(fake_file_id, _FakeFileInfo()))
mock_manager = AsyncMock()
mock_manager.read_file_by_id = AsyncMock(return_value=fake_content)
with (
patch("backend.copilot.tools.workspace_files._resolve_file", mock_resolve),
patch(
"backend.copilot.tools.workspace_files.get_workspace_manager",
AsyncMock(return_value=mock_manager),
),
patch(
"backend.copilot.tools.workspace_files._read_local_tool_result"
) as patched_local,
):
read_tool = ReadWorkspaceFileTool()
result = await read_tool._execute(
user_id=user.id,
session=session,
file_id=fake_file_id,
)
# Fallback must not have been called.
patched_local.assert_not_called()
# Normal workspace path must have produced a content response.
assert isinstance(result, WorkspaceFileContentResponse)
assert base64.b64decode(result.content_base64) == fake_content

View File

@@ -70,10 +70,6 @@ def _msg_tokens(msg: dict, enc) -> int:
# Count tool result tokens
tool_call_tokens += _tok_len(item.get("tool_use_id", ""), enc)
tool_call_tokens += _tok_len(item.get("content", ""), enc)
elif isinstance(item, dict) and item.get("type") == "text":
# Count text block tokens (standard: "text" key, fallback: "content")
text_val = item.get("text") or item.get("content", "")
tool_call_tokens += _tok_len(text_val, enc)
elif isinstance(item, dict) and "content" in item:
# Other content types with content field
tool_call_tokens += _tok_len(item.get("content", ""), enc)
@@ -149,16 +145,10 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
if len(ids) <= max_tok:
return text # nothing to do
# Need at least 3 tokens (head + ellipsis + tail) for meaningful truncation
if max_tok < 1:
return ""
mid = enc.encode("")
if max_tok < 3:
return enc.decode(ids[:max_tok])
# Split the allowance between the two ends:
head = max_tok // 2 - 1 # -1 for the ellipsis
tail = max_tok - head - 1
mid = enc.encode("")
return enc.decode(ids[:head] + mid + ids[-tail:])
@@ -555,14 +545,6 @@ async def _summarize_messages_llm(
"- Actions taken and key decisions made\n"
"- Technical specifics (file names, tool outputs, function signatures)\n"
"- Errors encountered and resolutions applied\n\n"
"IMPORTANT: Preserve all concrete references verbatim — these are small but "
"critical for continuing the conversation:\n"
"- File paths and directory paths (e.g. /src/app/page.tsx, ./output/result.csv)\n"
"- Image/media file paths from tool outputs\n"
"- URLs, API endpoints, and webhook addresses\n"
"- Resource IDs, session IDs, and identifiers\n"
"- Tool names that were called and their key parameters\n"
"- Environment variables, config keys, and credentials names (not values)\n\n"
"Include ONLY the sections below that have relevant content "
"(skip sections with nothing to report):\n\n"
"## 1. Primary Request and Intent\n"
@@ -570,8 +552,7 @@ async def _summarize_messages_llm(
"## 2. Key Technical Concepts\n"
"Technologies, frameworks, tools, and patterns being used or discussed.\n\n"
"## 3. Files and Resources Involved\n"
"Specific files examined or modified, with relevant snippets and identifiers. "
"Include exact file paths, image paths from tool outputs, and resource URLs.\n\n"
"Specific files examined or modified, with relevant snippets and identifiers.\n\n"
"## 4. Errors and Fixes\n"
"Problems encountered, error messages, and their resolutions.\n\n"
"## 5. All User Messages\n"
@@ -585,7 +566,7 @@ async def _summarize_messages_llm(
},
{"role": "user", "content": f"Summarize:\n\n{conversation_text}"},
],
max_tokens=2000,
max_tokens=1500,
temperature=0.3,
)
@@ -705,15 +686,11 @@ async def compress_context(
msgs = [summary_msg] + recent_msgs
logger.info(
"Context summarized: %d -> %d tokens, summarized %d messages",
original_count,
total_tokens(),
messages_summarized,
f"Context summarized: {original_count} -> {total_tokens()} tokens, "
f"summarized {messages_summarized} messages"
)
except Exception as e:
logger.warning(
"Summarization failed, continuing with truncation: %s", e
)
logger.warning(f"Summarization failed, continuing with truncation: {e}")
# Fall through to content truncation
# ---- STEP 2: Normalize content ----------------------------------------