mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
59 Commits
fix/copilo
...
fix/copilo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3c45687e10 | ||
|
|
869743ff0e | ||
|
|
46c35cfca6 | ||
|
|
8748b3e49d | ||
|
|
5f4e5eb207 | ||
|
|
2479de7ac9 | ||
|
|
f4dee98508 | ||
|
|
bd23caa116 | ||
|
|
17bbd18521 | ||
|
|
de73d89e39 | ||
|
|
29efcfb280 | ||
|
|
f1151c5cc1 | ||
|
|
11dbc08450 | ||
|
|
bca314cfbe | ||
|
|
c4a51d2804 | ||
|
|
e17e1616d9 | ||
|
|
ca0b3cde16 | ||
|
|
045096d863 | ||
|
|
fc844fde1f | ||
|
|
9642332332 | ||
|
|
47d91e915f | ||
|
|
df75e130da | ||
|
|
d0fc7ed3b2 | ||
|
|
9781aa93e3 | ||
|
|
f043fa7b6a | ||
|
|
ca4dad979d | ||
|
|
4559d13b29 | ||
|
|
4cc1baac54 | ||
|
|
9d1881d909 | ||
|
|
384b261e7f | ||
|
|
4cc0bbf472 | ||
|
|
3082f878fe | ||
|
|
33cd967e66 | ||
|
|
b599858dea | ||
|
|
629ecc9436 | ||
|
|
4b92fd09c9 | ||
|
|
41872e003b | ||
|
|
5dc8d6c848 | ||
|
|
8c8e596302 | ||
|
|
ad6e2f0ca1 | ||
|
|
d1ef92a79a | ||
|
|
15d36233b6 | ||
|
|
618dde9d02 | ||
|
|
39c0fece87 | ||
|
|
41591fd76f | ||
|
|
7d95321fd9 | ||
|
|
4ebc759f0a | ||
|
|
3e509847fd | ||
|
|
1023134458 | ||
|
|
8f0f6ced10 | ||
|
|
9f60fda37f | ||
|
|
b04f806760 | ||
|
|
0246623337 | ||
|
|
696f533e2e | ||
|
|
8c7b077753 | ||
|
|
a1f34316c7 | ||
|
|
152f54f33d | ||
|
|
6baeb117f7 | ||
|
|
2adeb63ebc |
@@ -115,7 +115,7 @@ class ChatConfig(BaseSettings):
|
||||
description="E2B sandbox template to use for copilot sessions.",
|
||||
)
|
||||
e2b_sandbox_timeout: int = Field(
|
||||
default=300, # 5 min safety net — explicit per-turn pause is the primary mechanism
|
||||
default=420, # 7 min safety net — allows headroom for compaction retries
|
||||
description="E2B sandbox running-time timeout (seconds). "
|
||||
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
|
||||
"mechanism; this is the safety net.",
|
||||
|
||||
@@ -17,17 +17,8 @@ from backend.util.workspace import WorkspaceManager
|
||||
if TYPE_CHECKING:
|
||||
from e2b import AsyncSandbox
|
||||
|
||||
# Allowed base directory for the Read tool. Public so service.py can use it
|
||||
# for sweep operations without depending on a private implementation detail.
|
||||
# Respects CLAUDE_CONFIG_DIR env var, consistent with transcript.py's
|
||||
# _projects_base() function.
|
||||
_config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||
SDK_PROJECTS_DIR = os.path.realpath(os.path.join(_config_dir, "projects"))
|
||||
|
||||
# Compiled UUID pattern for validating conversation directory names.
|
||||
# Kept as a module-level constant so the security-relevant pattern is easy
|
||||
# to audit in one place and avoids recompilation on every call.
|
||||
_UUID_RE = re.compile(r"^[0-9a-f]{8}(?:-[0-9a-f]{4}){3}-[0-9a-f]{12}$", re.IGNORECASE)
|
||||
# Allowed base directory for the Read tool.
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
# Encoded project-directory name for the current session (e.g.
|
||||
# "-private-tmp-copilot-<uuid>"). Set by set_execution_context() so path
|
||||
@@ -44,20 +35,11 @@ _current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
|
||||
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
|
||||
|
||||
|
||||
def encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does.
|
||||
|
||||
The Claude CLI encodes the absolute cwd as a directory name by replacing
|
||||
every non-alphanumeric character with ``-``. For example
|
||||
``/tmp/copilot-abc`` becomes ``-tmp-copilot-abc``.
|
||||
"""
|
||||
def _encode_cwd_for_cli(cwd: str) -> str:
|
||||
"""Encode a working directory path the same way the Claude CLI does."""
|
||||
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
|
||||
|
||||
|
||||
# Keep the private alias for internal callers (backwards compat).
|
||||
_encode_cwd_for_cli = encode_cwd_for_cli
|
||||
|
||||
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
@@ -118,9 +100,7 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
|
||||
Allowed:
|
||||
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``.
|
||||
The SDK nests tool-results under a conversation UUID directory;
|
||||
the UUID segment is validated with ``_UUID_RE``.
|
||||
- Files under ``~/.claude/projects/<encoded-cwd>/tool-results/`` (SDK tool-results)
|
||||
"""
|
||||
if not path:
|
||||
return False
|
||||
@@ -139,22 +119,10 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
|
||||
|
||||
encoded = _current_project_dir.get("")
|
||||
if encoded:
|
||||
project_dir = os.path.realpath(os.path.join(SDK_PROJECTS_DIR, encoded))
|
||||
# Defence-in-depth: ensure project_dir didn't escape the base.
|
||||
if not project_dir.startswith(SDK_PROJECTS_DIR + os.sep):
|
||||
return False
|
||||
# Only allow: <encoded-cwd>/<uuid>/tool-results/<file>
|
||||
# The SDK always creates a conversation UUID directory between
|
||||
# the project dir and tool-results/.
|
||||
if resolved.startswith(project_dir + os.sep):
|
||||
relative = resolved[len(project_dir) + 1 :]
|
||||
parts = relative.split(os.sep)
|
||||
# Require exactly: [<uuid>, "tool-results", <file>, ...]
|
||||
if (
|
||||
len(parts) >= 3
|
||||
and _UUID_RE.match(parts[0])
|
||||
and parts[1] == "tool-results"
|
||||
):
|
||||
return True
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
if resolved == tool_results_dir or resolved.startswith(
|
||||
tool_results_dir + os.sep
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@@ -9,7 +9,7 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import (
|
||||
SDK_PROJECTS_DIR,
|
||||
_SDK_PROJECTS_DIR,
|
||||
_current_project_dir,
|
||||
get_current_sandbox,
|
||||
get_execution_context,
|
||||
@@ -104,13 +104,11 @@ def test_is_allowed_local_path_no_sdk_cwd_no_project_dir():
|
||||
assert not is_allowed_local_path("/tmp/some-file.txt", sdk_cwd=None)
|
||||
|
||||
|
||||
def test_is_allowed_local_path_tool_results_with_uuid():
|
||||
"""Files under <encoded-cwd>/<uuid>/tool-results/ are allowed."""
|
||||
def test_is_allowed_local_path_tool_results_dir():
|
||||
"""Files under the tool-results directory for the current project are allowed."""
|
||||
encoded = "test-encoded-dir"
|
||||
conv_uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
path = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-results", "output.txt"
|
||||
)
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
path = os.path.join(tool_results_dir, "output.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
@@ -119,22 +117,10 @@ def test_is_allowed_local_path_tool_results_with_uuid():
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_tool_results_without_uuid_rejected():
|
||||
"""Direct <encoded-cwd>/tool-results/ (no UUID) is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
path = os.path.join(SDK_PROJECTS_DIR, encoded, "tool-results", "output.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
assert not is_allowed_local_path(path, sdk_cwd=None)
|
||||
finally:
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
|
||||
"""A path adjacent to tool-results/ but not inside it is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
sibling_path = os.path.join(SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
|
||||
sibling_path = os.path.join(_SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
@@ -143,21 +129,6 @@ def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
def test_is_allowed_local_path_valid_uuid_wrong_segment_name_rejected():
|
||||
"""A valid UUID dir but non-'tool-results' second segment is rejected."""
|
||||
encoded = "test-encoded-dir"
|
||||
uuid_str = "12345678-1234-5678-9abc-def012345678"
|
||||
path = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, uuid_str, "not-tool-results", "output.txt"
|
||||
)
|
||||
|
||||
_current_project_dir.set(encoded)
|
||||
try:
|
||||
assert not is_allowed_local_path(path, sdk_cwd=None)
|
||||
finally:
|
||||
_current_project_dir.set("")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_sandbox_path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -152,13 +152,6 @@ def _build_storage_supplement(
|
||||
|
||||
### File persistence
|
||||
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
|
||||
|
||||
### SDK tool-result files
|
||||
When tool outputs are large, the SDK truncates them and saves the full output to
|
||||
a local file under `~/.claude/projects/.../tool-results/`. To read these files,
|
||||
always use `read_file` or `Read` (NOT `read_workspace_file`).
|
||||
`read_workspace_file` reads from cloud workspace storage, where SDK
|
||||
tool-results are NOT stored.
|
||||
{_SHARED_TOOL_NOTES}"""
|
||||
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ class ResponseType(str, Enum):
|
||||
ERROR = "error"
|
||||
USAGE = "usage"
|
||||
HEARTBEAT = "heartbeat"
|
||||
STATUS = "status"
|
||||
|
||||
|
||||
class StreamBaseResponse(BaseModel):
|
||||
@@ -232,3 +233,26 @@ class StreamHeartbeat(StreamBaseResponse):
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE comment format to keep connection alive."""
|
||||
return ": heartbeat\n\n"
|
||||
|
||||
|
||||
class StreamStatus(StreamBaseResponse):
|
||||
"""Transient status notification shown to the user during long operations.
|
||||
|
||||
Used to provide feedback when the backend performs behind-the-scenes work
|
||||
(e.g., compacting conversation context on a retry) that would otherwise
|
||||
leave the user staring at an unexplained pause.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.STATUS
|
||||
message: str = Field(..., description="Human-readable status message")
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Encode as an SSE comment so the AI SDK stream parser ignores it.
|
||||
|
||||
The frontend AI SDK validates every ``data:`` line against a strict
|
||||
Zod union of known chunk types. ``"status"`` is not in that union,
|
||||
so sending it as ``data:`` would cause a schema-validation error that
|
||||
breaks the entire stream. Using an SSE comment (``:``) keeps the
|
||||
connection alive and is silently discarded by ``EventSource`` parsers.
|
||||
"""
|
||||
return f": status {self.message}\n\n"
|
||||
|
||||
@@ -12,6 +12,7 @@ import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
|
||||
from ..model import ChatMessage, ChatSession
|
||||
@@ -119,14 +120,12 @@ def filter_compaction_messages(
|
||||
filtered: list[ChatMessage] = []
|
||||
for msg in messages:
|
||||
if msg.role == "assistant" and msg.tool_calls:
|
||||
real_calls: list[dict[str, Any]] = []
|
||||
for tc in msg.tool_calls:
|
||||
if tc.get("function", {}).get("name") == COMPACTION_TOOL_NAME:
|
||||
compaction_ids.add(tc.get("id", ""))
|
||||
real_calls = [
|
||||
tc
|
||||
for tc in msg.tool_calls
|
||||
if tc.get("function", {}).get("name") != COMPACTION_TOOL_NAME
|
||||
]
|
||||
else:
|
||||
real_calls.append(tc)
|
||||
if not real_calls and not msg.content:
|
||||
continue
|
||||
if msg.role == "tool" and msg.tool_call_id in compaction_ids:
|
||||
@@ -222,6 +221,7 @@ class CompactionTracker:
|
||||
|
||||
def reset_for_query(self) -> None:
|
||||
"""Reset per-query state before a new SDK query."""
|
||||
self._compact_start.clear()
|
||||
self._done = False
|
||||
self._start_emitted = False
|
||||
self._tool_call_id = ""
|
||||
|
||||
41
autogpt_platform/backend/backend/copilot/sdk/conftest.py
Normal file
41
autogpt_platform/backend/backend/copilot/sdk/conftest.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Shared test fixtures for copilot SDK tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from backend.util import json
|
||||
|
||||
|
||||
def build_test_transcript(pairs: list[tuple[str, str]]) -> str:
|
||||
"""Build a minimal valid JSONL transcript from (role, content) pairs.
|
||||
|
||||
Use this helper in any copilot SDK test that needs a well-formed
|
||||
transcript without hitting the real storage layer.
|
||||
"""
|
||||
lines: list[str] = []
|
||||
last_uuid: str | None = None
|
||||
for role, content in pairs:
|
||||
uid = str(uuid4())
|
||||
entry_type = "assistant" if role == "assistant" else "user"
|
||||
msg: dict = {"role": role, "content": content}
|
||||
if role == "assistant":
|
||||
msg.update(
|
||||
{
|
||||
"model": "",
|
||||
"id": f"msg_{uid[:8]}",
|
||||
"type": "message",
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
}
|
||||
)
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": last_uuid,
|
||||
"message": msg,
|
||||
}
|
||||
lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
last_uuid = uid
|
||||
return "\n".join(lines) + "\n"
|
||||
@@ -26,41 +26,6 @@ from backend.copilot.context import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _check_sandbox_symlink_escape(
|
||||
sandbox: Any,
|
||||
parent: str,
|
||||
) -> str | None:
|
||||
"""Resolve the canonical parent path inside the sandbox to detect symlink escapes.
|
||||
|
||||
``normpath`` (used by ``resolve_sandbox_path``) only normalises the string;
|
||||
``readlink -f`` follows actual symlinks on the sandbox filesystem.
|
||||
|
||||
Returns the canonical parent path, or ``None`` if the path escapes
|
||||
``E2B_WORKDIR``.
|
||||
|
||||
Note: There is an inherent TOCTOU window between this check and the
|
||||
subsequent ``sandbox.files.write()``. A symlink could theoretically be
|
||||
replaced between the two operations. This is acceptable in the E2B
|
||||
sandbox model since the sandbox is single-user and ephemeral.
|
||||
"""
|
||||
canonical_res = await sandbox.commands.run(
|
||||
f"readlink -f {shlex.quote(parent or E2B_WORKDIR)}",
|
||||
cwd=E2B_WORKDIR,
|
||||
timeout=5,
|
||||
)
|
||||
canonical_parent = (canonical_res.stdout or "").strip()
|
||||
if (
|
||||
canonical_res.exit_code != 0
|
||||
or not canonical_parent
|
||||
or (
|
||||
canonical_parent != E2B_WORKDIR
|
||||
and not canonical_parent.startswith(E2B_WORKDIR + "/")
|
||||
)
|
||||
):
|
||||
return None
|
||||
return canonical_parent
|
||||
|
||||
|
||||
def _get_sandbox():
|
||||
return get_current_sandbox()
|
||||
|
||||
@@ -141,10 +106,6 @@ async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
parent = os.path.dirname(remote)
|
||||
if parent and parent != E2B_WORKDIR:
|
||||
await sandbox.files.make_dir(parent)
|
||||
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
|
||||
if canonical_parent is None:
|
||||
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
|
||||
remote = os.path.join(canonical_parent, os.path.basename(remote))
|
||||
await sandbox.files.write(remote, content)
|
||||
except Exception as exc:
|
||||
return _mcp(f"Failed to write {remote}: {exc}", error=True)
|
||||
@@ -169,12 +130,6 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
|
||||
return result
|
||||
sandbox, remote = result
|
||||
|
||||
parent = os.path.dirname(remote)
|
||||
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
|
||||
if canonical_parent is None:
|
||||
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
|
||||
remote = os.path.join(canonical_parent, os.path.basename(remote))
|
||||
|
||||
try:
|
||||
raw: bytes = await sandbox.files.read(remote, format="bytes")
|
||||
content = raw.decode("utf-8", errors="replace")
|
||||
|
||||
@@ -4,19 +4,15 @@ Pure unit tests with no external dependencies (no E2B, no sandbox).
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import E2B_WORKDIR, SDK_PROJECTS_DIR, _current_project_dir
|
||||
from backend.copilot.context import _current_project_dir
|
||||
|
||||
from .e2b_file_tools import _read_local, resolve_sandbox_path
|
||||
|
||||
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
||||
|
||||
from .e2b_file_tools import (
|
||||
_check_sandbox_symlink_escape,
|
||||
_read_local,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
|
||||
@@ -25,48 +21,46 @@ from .e2b_file_tools import (
|
||||
|
||||
class TestResolveSandboxPath:
|
||||
def test_relative_path_resolved(self):
|
||||
assert resolve_sandbox_path("src/main.py") == f"{E2B_WORKDIR}/src/main.py"
|
||||
assert resolve_sandbox_path("src/main.py") == "/home/user/src/main.py"
|
||||
|
||||
def test_absolute_within_sandbox(self):
|
||||
assert (
|
||||
resolve_sandbox_path(f"{E2B_WORKDIR}/file.txt") == f"{E2B_WORKDIR}/file.txt"
|
||||
)
|
||||
assert resolve_sandbox_path("/home/user/file.txt") == "/home/user/file.txt"
|
||||
|
||||
def test_workdir_itself(self):
|
||||
assert resolve_sandbox_path(E2B_WORKDIR) == E2B_WORKDIR
|
||||
assert resolve_sandbox_path("/home/user") == "/home/user"
|
||||
|
||||
def test_relative_dotslash(self):
|
||||
assert resolve_sandbox_path("./README.md") == f"{E2B_WORKDIR}/README.md"
|
||||
assert resolve_sandbox_path("./README.md") == "/home/user/README.md"
|
||||
|
||||
def test_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("../../etc/passwd")
|
||||
|
||||
def test_absolute_traversal_blocked(self):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
resolve_sandbox_path(f"{E2B_WORKDIR}/../../etc/passwd")
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/home/user/../../etc/passwd")
|
||||
|
||||
def test_absolute_outside_sandbox_blocked(self):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/etc/passwd")
|
||||
|
||||
def test_root_blocked(self):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/")
|
||||
|
||||
def test_home_other_user_blocked(self):
|
||||
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
|
||||
with pytest.raises(ValueError, match="must be within /home/user"):
|
||||
resolve_sandbox_path("/home/other/file.txt")
|
||||
|
||||
def test_deep_nested_allowed(self):
|
||||
assert resolve_sandbox_path("a/b/c/d/e.txt") == f"{E2B_WORKDIR}/a/b/c/d/e.txt"
|
||||
assert resolve_sandbox_path("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
|
||||
|
||||
def test_trailing_slash_normalised(self):
|
||||
assert resolve_sandbox_path("src/") == f"{E2B_WORKDIR}/src"
|
||||
assert resolve_sandbox_path("src/") == "/home/user/src"
|
||||
|
||||
def test_double_dots_within_sandbox_ok(self):
|
||||
"""Path that resolves back within E2B_WORKDIR is allowed."""
|
||||
assert resolve_sandbox_path("a/b/../c.txt") == f"{E2B_WORKDIR}/a/c.txt"
|
||||
"""Path that resolves back within /home/user is allowed."""
|
||||
assert resolve_sandbox_path("a/b/../c.txt") == "/home/user/a/c.txt"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -79,13 +73,9 @@ class TestResolveSandboxPath:
|
||||
|
||||
|
||||
class TestReadLocal:
|
||||
_CONV_UUID = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
|
||||
def _make_tool_results_file(self, encoded: str, filename: str, content: str) -> str:
|
||||
"""Create a tool-results file under <encoded>/<uuid>/tool-results/."""
|
||||
tool_results_dir = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-results"
|
||||
)
|
||||
"""Create a tool-results file and return its path."""
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
os.makedirs(tool_results_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_results_dir, filename)
|
||||
with open(filepath, "w") as f:
|
||||
@@ -117,9 +107,7 @@ class TestReadLocal:
|
||||
def test_read_nonexistent_tool_results(self):
|
||||
"""A tool-results path that doesn't exist returns FileNotFoundError."""
|
||||
encoded = "-tmp-copilot-e2b-test-nofile"
|
||||
tool_results_dir = os.path.join(
|
||||
SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-results"
|
||||
)
|
||||
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
|
||||
os.makedirs(tool_results_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_results_dir, "nonexistent.txt")
|
||||
token = _current_project_dir.set(encoded)
|
||||
@@ -129,7 +117,7 @@ class TestReadLocal:
|
||||
assert "not found" in result["content"][0]["text"].lower()
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
|
||||
os.rmdir(tool_results_dir)
|
||||
|
||||
def test_read_traversal_path_blocked(self):
|
||||
"""A traversal attempt that escapes allowed directories is blocked."""
|
||||
@@ -164,66 +152,3 @@ class TestReadLocal:
|
||||
"""Without _current_project_dir set, all paths are blocked."""
|
||||
result = _read_local("/tmp/anything.txt", offset=0, limit=10)
|
||||
assert result["isError"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _check_sandbox_symlink_escape — symlink escape detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_sandbox(stdout: str, exit_code: int = 0) -> SimpleNamespace:
|
||||
"""Build a minimal sandbox mock whose commands.run returns a fixed result."""
|
||||
run_result = SimpleNamespace(stdout=stdout, exit_code=exit_code)
|
||||
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
|
||||
return SimpleNamespace(commands=commands)
|
||||
|
||||
|
||||
class TestCheckSandboxSymlinkEscape:
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonical_path_within_workdir_returns_path(self):
|
||||
"""When readlink -f resolves to a path inside E2B_WORKDIR, returns it."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}/src\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
|
||||
assert result == f"{E2B_WORKDIR}/src"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workdir_itself_returns_workdir(self):
|
||||
"""When readlink -f resolves to E2B_WORKDIR exactly, returns E2B_WORKDIR."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, E2B_WORKDIR)
|
||||
assert result == E2B_WORKDIR
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_symlink_escape_returns_none(self):
|
||||
"""When readlink -f resolves outside E2B_WORKDIR (symlink escape), returns None."""
|
||||
sandbox = _make_sandbox(stdout="/etc\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/evil")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonzero_exit_code_returns_none(self):
|
||||
"""A non-zero exit code from readlink -f returns None."""
|
||||
sandbox = _make_sandbox(stdout="", exit_code=1)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_stdout_returns_none(self):
|
||||
"""Empty stdout from readlink (e.g. path doesn't exist yet) returns None."""
|
||||
sandbox = _make_sandbox(stdout="", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prefix_collision_returns_none(self):
|
||||
"""A path prefixed with E2B_WORKDIR but not within it is rejected."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}-evil\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}-evil")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deeply_nested_path_within_workdir(self):
|
||||
"""Deep nested paths inside E2B_WORKDIR are allowed."""
|
||||
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}/a/b/c/d\n", exit_code=0)
|
||||
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/a/b/c/d")
|
||||
assert result == f"{E2B_WORKDIR}/a/b/c/d"
|
||||
|
||||
@@ -0,0 +1,552 @@
|
||||
"""Tests for retry logic and transcript compaction helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import _is_prompt_too_long
|
||||
from .transcript import (
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_assistant_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlattenAssistantContent:
|
||||
def test_text_blocks(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "text", "text": "World"},
|
||||
]
|
||||
assert _flatten_assistant_content(blocks) == "Hello\nWorld"
|
||||
|
||||
def test_tool_use_blocks(self):
|
||||
blocks = [{"type": "tool_use", "name": "read_file", "input": {}}]
|
||||
assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
|
||||
|
||||
def test_mixed_blocks(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "Let me read that."},
|
||||
{"type": "tool_use", "name": "Read", "input": {"path": "/foo"}},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "Let me read that." in result
|
||||
assert "[tool_use: Read]" in result
|
||||
|
||||
def test_raw_strings(self):
|
||||
assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"
|
||||
|
||||
def test_unknown_block_type_preserved_as_placeholder(self):
|
||||
blocks = [
|
||||
{"type": "text", "text": "See this image:"},
|
||||
{"type": "image", "source": {"type": "base64", "data": "..."}},
|
||||
]
|
||||
result = _flatten_assistant_content(blocks)
|
||||
assert "See this image:" in result
|
||||
assert "[__image__]" in result
|
||||
|
||||
def test_empty(self):
|
||||
assert _flatten_assistant_content([]) == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_tool_result_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFlattenToolResultContent:
|
||||
def test_tool_result_with_text(self):
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "123",
|
||||
"content": [{"type": "text", "text": "file contents here"}],
|
||||
}
|
||||
]
|
||||
assert _flatten_tool_result_content(blocks) == "file contents here"
|
||||
|
||||
def test_tool_result_with_string_content(self):
|
||||
blocks = [{"type": "tool_result", "tool_use_id": "123", "content": "ok"}]
|
||||
assert _flatten_tool_result_content(blocks) == "ok"
|
||||
|
||||
def test_text_block(self):
|
||||
blocks = [{"type": "text", "text": "plain text"}]
|
||||
assert _flatten_tool_result_content(blocks) == "plain text"
|
||||
|
||||
def test_raw_string(self):
|
||||
assert _flatten_tool_result_content(["raw"]) == "raw"
|
||||
|
||||
def test_tool_result_with_none_content(self):
|
||||
"""tool_result with content=None should produce empty string."""
|
||||
blocks = [{"type": "tool_result", "tool_use_id": "x", "content": None}]
|
||||
assert _flatten_tool_result_content(blocks) == ""
|
||||
|
||||
def test_tool_result_with_empty_list_content(self):
|
||||
"""tool_result with content=[] should produce empty string."""
|
||||
blocks = [{"type": "tool_result", "tool_use_id": "x", "content": []}]
|
||||
assert _flatten_tool_result_content(blocks) == ""
|
||||
|
||||
def test_empty(self):
|
||||
assert _flatten_tool_result_content([]) == ""
|
||||
|
||||
def test_nested_dict_without_text(self):
|
||||
"""Dict blocks without text key use json.dumps fallback."""
|
||||
blocks = [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "x",
|
||||
"content": [{"type": "image", "source": "data:..."}],
|
||||
}
|
||||
]
|
||||
result = _flatten_tool_result_content(blocks)
|
||||
assert "image" in result # json.dumps fallback
|
||||
|
||||
def test_unknown_block_type_preserved_as_placeholder(self):
|
||||
blocks = [{"type": "image", "source": {"type": "base64", "data": "..."}}]
|
||||
result = _flatten_tool_result_content(blocks)
|
||||
assert "[__image__]" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _transcript_to_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_entry(entry_type: str, role: str, content: str | list, **kwargs) -> str:
|
||||
"""Build a JSONL line for testing."""
|
||||
uid = str(uuid4())
|
||||
msg: dict = {"role": role, "content": content}
|
||||
msg.update(kwargs)
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": None,
|
||||
"message": msg,
|
||||
}
|
||||
return json.dumps(entry, separators=(",", ":"))
|
||||
|
||||
|
||||
class TestTranscriptToMessages:
|
||||
def test_basic_roundtrip(self):
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
assert messages[0] == {"role": "user", "content": "Hello"}
|
||||
assert messages[1] == {"role": "assistant", "content": "Hi"}
|
||||
|
||||
def test_skips_strippable_types(self):
|
||||
"""Progress and metadata entries are excluded."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
json.dumps(
|
||||
{
|
||||
"type": "progress",
|
||||
"uuid": str(uuid4()),
|
||||
"parentUuid": None,
|
||||
"message": {"role": "assistant", "content": "..."},
|
||||
}
|
||||
),
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_empty_content(self):
|
||||
assert _transcript_to_messages("") == []
|
||||
|
||||
def test_tool_result_content(self):
|
||||
"""User entries with tool_result content blocks are flattened."""
|
||||
lines = [
|
||||
_make_entry(
|
||||
"user",
|
||||
"user",
|
||||
[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "123",
|
||||
"content": "tool output",
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"] == "tool output"
|
||||
|
||||
def test_malformed_json_lines_skipped(self):
|
||||
"""Malformed JSON lines in transcript are silently skipped."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
"this is not valid json",
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_empty_lines_skipped(self):
|
||||
"""Empty lines and whitespace-only lines are skipped."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello"),
|
||||
"",
|
||||
" ",
|
||||
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
|
||||
def test_unicode_content_preserved(self):
|
||||
"""Unicode characters survive transcript roundtrip."""
|
||||
lines = [
|
||||
_make_entry("user", "user", "Hello 你好 🌍"),
|
||||
_make_entry(
|
||||
"assistant",
|
||||
"assistant",
|
||||
[{"type": "text", "text": "Bonjour 日本語 émojis 🎉"}],
|
||||
),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert messages[0]["content"] == "Hello 你好 🌍"
|
||||
assert messages[1]["content"] == "Bonjour 日本語 émojis 🎉"
|
||||
|
||||
def test_entry_without_role_skipped(self):
|
||||
"""Entries with missing role in message are skipped."""
|
||||
entry_no_role = json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": str(uuid4()),
|
||||
"parentUuid": None,
|
||||
"message": {"content": "no role here"},
|
||||
}
|
||||
)
|
||||
lines = [
|
||||
entry_no_role,
|
||||
_make_entry("user", "user", "Hello"),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"] == "Hello"
|
||||
|
||||
def test_tool_use_and_result_pairs(self):
|
||||
"""Tool use + tool result pairs are properly flattened."""
|
||||
lines = [
|
||||
_make_entry(
|
||||
"assistant",
|
||||
"assistant",
|
||||
[
|
||||
{"type": "text", "text": "Let me check."},
|
||||
{"type": "tool_use", "name": "read_file", "input": {"path": "/x"}},
|
||||
],
|
||||
),
|
||||
_make_entry(
|
||||
"user",
|
||||
"user",
|
||||
[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "abc",
|
||||
"content": [{"type": "text", "text": "file contents"}],
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
content = "\n".join(lines) + "\n"
|
||||
messages = _transcript_to_messages(content)
|
||||
assert len(messages) == 2
|
||||
assert "Let me check." in messages[0]["content"]
|
||||
assert "[tool_use: read_file]" in messages[0]["content"]
|
||||
assert messages[1]["content"] == "file contents"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _messages_to_transcript
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessagesToTranscript:
|
||||
def test_produces_valid_jsonl(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
lines = result.strip().split("\n")
|
||||
assert len(lines) == 2
|
||||
for line in lines:
|
||||
parsed = json.loads(line)
|
||||
assert "type" in parsed
|
||||
assert "uuid" in parsed
|
||||
assert "message" in parsed
|
||||
|
||||
def test_assistant_has_proper_structure(self):
|
||||
messages = [{"role": "assistant", "content": "Hello"}]
|
||||
result = _messages_to_transcript(messages)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["type"] == "assistant"
|
||||
msg = entry["message"]
|
||||
assert msg["role"] == "assistant"
|
||||
assert msg["type"] == "message"
|
||||
assert msg["stop_reason"] == "end_turn"
|
||||
assert isinstance(msg["content"], list)
|
||||
assert msg["content"][0]["type"] == "text"
|
||||
|
||||
def test_user_has_plain_content(self):
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
result = _messages_to_transcript(messages)
|
||||
entry = json.loads(result.strip())
|
||||
assert entry["type"] == "user"
|
||||
assert entry["message"]["content"] == "Hi"
|
||||
|
||||
def test_parent_uuid_chain(self):
|
||||
messages = [
|
||||
{"role": "user", "content": "A"},
|
||||
{"role": "assistant", "content": "B"},
|
||||
{"role": "user", "content": "C"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
lines = result.strip().split("\n")
|
||||
entries = [json.loads(line) for line in lines]
|
||||
assert entries[0]["parentUuid"] == ""
|
||||
assert entries[1]["parentUuid"] == entries[0]["uuid"]
|
||||
assert entries[2]["parentUuid"] == entries[1]["uuid"]
|
||||
|
||||
def test_empty_messages(self):
|
||||
assert _messages_to_transcript([]) == ""
|
||||
|
||||
def test_output_is_valid_transcript(self):
|
||||
"""Output should pass validate_transcript if it has assistant entries."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"},
|
||||
]
|
||||
result = _messages_to_transcript(messages)
|
||||
assert validate_transcript(result)
|
||||
|
||||
def test_roundtrip_to_messages(self):
|
||||
"""Messages → transcript → messages preserves structure."""
|
||||
original = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
]
|
||||
transcript = _messages_to_transcript(original)
|
||||
restored = _transcript_to_messages(transcript)
|
||||
assert len(restored) == len(original)
|
||||
for orig, rest in zip(original, restored):
|
||||
assert orig["role"] == rest["role"]
|
||||
assert orig["content"] == rest["content"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compact_transcript
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactTranscript:
|
||||
@pytest.mark.asyncio
|
||||
async def test_too_few_messages_returns_none(self):
|
||||
"""compact_transcript returns None when transcript has < 2 messages."""
|
||||
transcript = _build_transcript([("user", "Hello")])
|
||||
with patch(
|
||||
"backend.copilot.config.ChatConfig",
|
||||
return_value=type(
|
||||
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
|
||||
)(),
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_not_compacted(self):
|
||||
"""When compress_context says no compaction needed, returns None.
|
||||
The compressor couldn't reduce it, so retrying with the same
|
||||
content would fail identically."""
|
||||
transcript = _build_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", "Hi there"),
|
||||
]
|
||||
)
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": False,
|
||||
"messages": [],
|
||||
"original_token_count": 100,
|
||||
"token_count": 100,
|
||||
"messages_summarized": 0,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.config.ChatConfig",
|
||||
return_value=type(
|
||||
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
|
||||
)(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
),
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_compacted_transcript(self):
|
||||
"""When compaction succeeds, returns a valid compacted transcript."""
|
||||
transcript = _build_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", "Hi"),
|
||||
("user", "More"),
|
||||
("assistant", "Details"),
|
||||
]
|
||||
)
|
||||
compacted_msgs = [
|
||||
{"role": "user", "content": "[summary]"},
|
||||
{"role": "assistant", "content": "Summarized response"},
|
||||
]
|
||||
mock_result = type(
|
||||
"CompressResult",
|
||||
(),
|
||||
{
|
||||
"was_compacted": True,
|
||||
"messages": compacted_msgs,
|
||||
"original_token_count": 500,
|
||||
"token_count": 100,
|
||||
"messages_summarized": 2,
|
||||
"messages_dropped": 0,
|
||||
},
|
||||
)()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.config.ChatConfig",
|
||||
return_value=type(
|
||||
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
|
||||
)(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
),
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is not None
|
||||
assert validate_transcript(result)
|
||||
msgs = _transcript_to_messages(result)
|
||||
assert len(msgs) == 2
|
||||
assert msgs[1]["content"] == "Summarized response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_compression_failure(self):
|
||||
"""When _run_compression raises, returns None."""
|
||||
transcript = _build_transcript(
|
||||
[
|
||||
("user", "Hello"),
|
||||
("assistant", "Hi"),
|
||||
]
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.config.ChatConfig",
|
||||
return_value=type(
|
||||
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
|
||||
)(),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("LLM unavailable"),
|
||||
),
|
||||
):
|
||||
result = await compact_transcript(transcript, model="test-model")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_prompt_too_long
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsPromptTooLong:
|
||||
"""Unit tests for _is_prompt_too_long pattern matching."""
|
||||
|
||||
def test_prompt_is_too_long(self):
|
||||
err = RuntimeError("prompt is too long for model context")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_request_too_large(self):
|
||||
err = Exception("request too large: 250000 tokens")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_maximum_context_length(self):
|
||||
err = ValueError("maximum context length exceeded")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_context_length_exceeded(self):
|
||||
err = Exception("context_length_exceeded")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_input_tokens_exceed(self):
|
||||
err = Exception("input tokens exceed the max_tokens limit")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_input_is_too_long(self):
|
||||
err = Exception("input is too long for the model")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_content_length_exceeds(self):
|
||||
err = Exception("content length exceeds maximum")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_unrelated_error_returns_false(self):
|
||||
err = RuntimeError("network timeout")
|
||||
assert _is_prompt_too_long(err) is False
|
||||
|
||||
def test_auth_error_returns_false(self):
|
||||
err = Exception("authentication failed: invalid API key")
|
||||
assert _is_prompt_too_long(err) is False
|
||||
|
||||
def test_chained_exception_detected(self):
|
||||
"""Prompt-too-long error wrapped in another exception is detected."""
|
||||
inner = RuntimeError("prompt is too long")
|
||||
outer = Exception("SDK error")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
|
||||
def test_case_insensitive(self):
|
||||
err = Exception("PROMPT IS TOO LONG")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_old_max_tokens_exceeded_not_matched(self):
|
||||
"""The old broad 'max_tokens_exceeded' pattern was removed.
|
||||
Only 'input tokens exceed' should match now."""
|
||||
err = Exception("max_tokens_exceeded")
|
||||
assert _is_prompt_too_long(err) is False
|
||||
@@ -226,7 +226,7 @@ class SDKResponseAdapter:
|
||||
responses.append(StreamFinish())
|
||||
|
||||
else:
|
||||
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
|
||||
logger.debug("Unhandled SDK message type: %s", type(sdk_message).__name__)
|
||||
|
||||
return responses
|
||||
|
||||
|
||||
1186
autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py
Normal file
1186
autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -42,7 +42,7 @@ def _validate_workspace_path(
|
||||
Delegates to :func:`is_allowed_local_path` which permits:
|
||||
- The SDK working directory (``/tmp/copilot-<session>/``)
|
||||
- The current session's tool-results directory
|
||||
(``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/``)
|
||||
(``~/.claude/projects/<encoded-cwd>/tool-results/``)
|
||||
"""
|
||||
path = tool_input.get("file_path") or tool_input.get("path") or ""
|
||||
if not path:
|
||||
@@ -52,7 +52,7 @@ def _validate_workspace_path(
|
||||
if is_allowed_local_path(path, sdk_cwd):
|
||||
return {}
|
||||
|
||||
logger.warning(f"Blocked {tool_name} outside workspace: {path}")
|
||||
logger.warning("Blocked %s outside workspace: %s", tool_name, path)
|
||||
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
|
||||
return _deny(
|
||||
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
|
||||
@@ -71,7 +71,7 @@ def _validate_tool_access(
|
||||
"""
|
||||
# Block forbidden tools
|
||||
if tool_name in BLOCKED_TOOLS:
|
||||
logger.warning(f"Blocked tool access attempt: {tool_name}")
|
||||
logger.warning("Blocked tool access attempt: %s", tool_name)
|
||||
return _deny(
|
||||
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
|
||||
"This is enforced by the platform and cannot be bypassed. "
|
||||
@@ -89,7 +89,9 @@ def _validate_tool_access(
|
||||
for pattern in DANGEROUS_PATTERNS:
|
||||
if re.search(pattern, input_str, re.IGNORECASE):
|
||||
logger.warning(
|
||||
f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
|
||||
"Blocked dangerous pattern in tool input: %s in %s",
|
||||
pattern,
|
||||
tool_name,
|
||||
)
|
||||
return _deny(
|
||||
"[SECURITY] Input contains a blocked pattern. "
|
||||
@@ -111,7 +113,9 @@ def _validate_user_isolation(
|
||||
# the tool itself via _validate_ephemeral_path.
|
||||
path = tool_input.get("path", "") or tool_input.get("file_path", "")
|
||||
if path and ".." in path:
|
||||
logger.warning(f"Blocked path traversal attempt: {path} by user {user_id}")
|
||||
logger.warning(
|
||||
"Blocked path traversal attempt: %s by user %s", path, user_id
|
||||
)
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
@@ -170,7 +174,7 @@ def create_security_hooks(
|
||||
# Block background task execution first — denied calls
|
||||
# should not consume a subtask slot.
|
||||
if tool_input.get("run_in_background"):
|
||||
logger.info(f"[SDK] Blocked background Task, user={user_id}")
|
||||
logger.info("[SDK] Blocked background Task, user=%s", user_id)
|
||||
return cast(
|
||||
SyncHookJSONOutput,
|
||||
_deny(
|
||||
@@ -181,7 +185,9 @@ def create_security_hooks(
|
||||
)
|
||||
if len(task_tool_use_ids) >= max_subtasks:
|
||||
logger.warning(
|
||||
f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
|
||||
"[SDK] Task limit reached (%d), user=%s",
|
||||
max_subtasks,
|
||||
user_id,
|
||||
)
|
||||
return cast(
|
||||
SyncHookJSONOutput,
|
||||
@@ -212,7 +218,7 @@ def create_security_hooks(
|
||||
if tool_name == "Task" and tool_use_id is not None:
|
||||
task_tool_use_ids.add(tool_use_id)
|
||||
|
||||
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
|
||||
logger.debug("[SDK] Tool start: %s, user=%s", tool_name, user_id)
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:
|
||||
@@ -282,8 +288,11 @@ def create_security_hooks(
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
error = input_data.get("error", "Unknown error")
|
||||
logger.warning(
|
||||
f"[SDK] Tool failed: {tool_name}, error={error}, "
|
||||
f"user={user_id}, tool_use_id={tool_use_id}"
|
||||
"[SDK] Tool failed: %s, error=%s, user=%s, tool_use_id=%s",
|
||||
tool_name,
|
||||
str(error).replace("\n", "").replace("\r", ""),
|
||||
user_id,
|
||||
tool_use_id,
|
||||
)
|
||||
|
||||
_release_task_slot(tool_name, tool_use_id)
|
||||
@@ -301,20 +310,19 @@ def create_security_hooks(
|
||||
This hook provides visibility into when compaction happens.
|
||||
"""
|
||||
_ = context, tool_use_id
|
||||
trigger = input_data.get("trigger", "auto")
|
||||
# Sanitize untrusted input: strip control chars for logging AND
|
||||
# for the value passed downstream. read_compacted_entries()
|
||||
# validates against _projects_base() as defence-in-depth, but
|
||||
# sanitizing here prevents log injection and rejects obviously
|
||||
# malformed paths early.
|
||||
# Sanitize untrusted input before logging to prevent log injection
|
||||
trigger = (
|
||||
str(input_data.get("trigger", "auto"))
|
||||
.replace("\n", "")
|
||||
.replace("\r", "")
|
||||
)
|
||||
transcript_path = (
|
||||
str(input_data.get("transcript_path", ""))
|
||||
.replace("\n", "")
|
||||
.replace("\r", "")
|
||||
)
|
||||
logger.info(
|
||||
"[SDK] Context compaction triggered: %s, user=%s, "
|
||||
"transcript_path=%s",
|
||||
"[SDK] Context compaction triggered: %s, user=%s, transcript_path=%s",
|
||||
trigger,
|
||||
user_id,
|
||||
transcript_path,
|
||||
|
||||
@@ -122,7 +122,7 @@ def test_read_no_cwd_denies_absolute():
|
||||
|
||||
def test_read_tool_results_allowed():
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/a1b2c3d4-e5f6-7890-abcd-ef1234567890/tool-results/12345.txt"
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
|
||||
# is_allowed_local_path requires the session's encoded cwd to be set
|
||||
token = _current_project_dir.set("-tmp-copilot-abc123")
|
||||
try:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,283 @@
|
||||
"""Unit tests for extracted service helpers.
|
||||
|
||||
Covers ``_is_prompt_too_long``, ``_reduce_context``, ``_iter_sdk_messages``,
|
||||
and the ``ReducedContext`` named tuple.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import (
|
||||
ReducedContext,
|
||||
_is_prompt_too_long,
|
||||
_iter_sdk_messages,
|
||||
_reduce_context,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_prompt_too_long
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsPromptTooLong:
|
||||
def test_direct_match(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("prompt is too long")) is True
|
||||
|
||||
def test_case_insensitive(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("PROMPT IS TOO LONG")) is True
|
||||
|
||||
def test_no_match(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("network timeout")) is False
|
||||
|
||||
def test_request_too_large(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("request too large for model")) is True
|
||||
|
||||
def test_context_length_exceeded(self) -> None:
|
||||
assert _is_prompt_too_long(Exception("context_length_exceeded")) is True
|
||||
|
||||
def test_max_tokens_exceeded_not_matched(self) -> None:
|
||||
"""'max_tokens_exceeded' is intentionally excluded (too broad)."""
|
||||
assert _is_prompt_too_long(Exception("max_tokens_exceeded")) is False
|
||||
|
||||
def test_max_tokens_config_error_no_match(self) -> None:
|
||||
"""'max_tokens must be at least 1' should NOT match."""
|
||||
assert _is_prompt_too_long(Exception("max_tokens must be at least 1")) is False
|
||||
|
||||
def test_chained_cause(self) -> None:
|
||||
inner = Exception("prompt is too long")
|
||||
outer = RuntimeError("SDK error")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
|
||||
def test_chained_context(self) -> None:
|
||||
inner = Exception("request too large")
|
||||
outer = RuntimeError("wrapped")
|
||||
outer.__context__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
|
||||
def test_deep_chain(self) -> None:
|
||||
bottom = Exception("maximum context length")
|
||||
middle = RuntimeError("middle")
|
||||
middle.__cause__ = bottom
|
||||
top = ValueError("top")
|
||||
top.__cause__ = middle
|
||||
assert _is_prompt_too_long(top) is True
|
||||
|
||||
def test_chain_no_match(self) -> None:
|
||||
inner = Exception("rate limit exceeded")
|
||||
outer = RuntimeError("wrapped")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is False
|
||||
|
||||
def test_cycle_detection(self) -> None:
|
||||
"""Exception chain with a cycle should not infinite-loop."""
|
||||
a = Exception("error a")
|
||||
b = Exception("error b")
|
||||
a.__cause__ = b
|
||||
b.__cause__ = a # cycle
|
||||
assert _is_prompt_too_long(a) is False
|
||||
|
||||
def test_all_patterns(self) -> None:
|
||||
patterns = [
|
||||
"prompt is too long",
|
||||
"request too large",
|
||||
"maximum context length",
|
||||
"context_length_exceeded",
|
||||
"input tokens exceed",
|
||||
"input is too long",
|
||||
"content length exceeds",
|
||||
]
|
||||
for pattern in patterns:
|
||||
assert _is_prompt_too_long(Exception(pattern)) is True, pattern
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _reduce_context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReduceContext:
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_retry_compaction_success(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
compacted = _build_transcript([("user", "hi"), ("assistant", "[summary]")])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=compacted,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.validate_transcript",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.write_transcript_to_tempfile",
|
||||
return_value="/tmp/resume.jsonl",
|
||||
),
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert isinstance(ctx, ReducedContext)
|
||||
assert ctx.use_resume is True
|
||||
assert ctx.resume_file == "/tmp/resume.jsonl"
|
||||
assert ctx.transcript_lost is False
|
||||
assert ctx.tried_compaction is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compaction_fails_drops_transcript(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert ctx.use_resume is False
|
||||
assert ctx.resume_file is None
|
||||
assert ctx.transcript_lost is True
|
||||
assert ctx.tried_compaction is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_already_tried_compaction_skips(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
|
||||
ctx = await _reduce_context(transcript, True, "sess-123", "/tmp/cwd", "[test]")
|
||||
|
||||
assert ctx.use_resume is False
|
||||
assert ctx.transcript_lost is True
|
||||
assert ctx.tried_compaction is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_transcript_drops(self) -> None:
|
||||
ctx = await _reduce_context("", False, "sess-123", "/tmp/cwd", "[test]")
|
||||
|
||||
assert ctx.use_resume is False
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compaction_returns_same_content_drops(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=transcript, # same content
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_tempfile_fails_drops(self) -> None:
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
compacted = _build_transcript([("user", "hi"), ("assistant", "[summary]")])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.compact_transcript",
|
||||
new_callable=AsyncMock,
|
||||
return_value=compacted,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.validate_transcript",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.service.write_transcript_to_tempfile",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
ctx = await _reduce_context(
|
||||
transcript, False, "sess-123", "/tmp/cwd", "[test]"
|
||||
)
|
||||
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _iter_sdk_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIterSdkMessages:
|
||||
@pytest.mark.asyncio
|
||||
async def test_yields_messages(self) -> None:
|
||||
messages = ["msg1", "msg2", "msg3"]
|
||||
client = AsyncMock()
|
||||
|
||||
async def _fake_receive() -> AsyncGenerator[str]:
|
||||
for m in messages:
|
||||
yield m
|
||||
|
||||
client.receive_response = _fake_receive
|
||||
result = [msg async for msg in _iter_sdk_messages(client)]
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_heartbeat_on_timeout(self) -> None:
|
||||
"""Yields None when asyncio.wait times out."""
|
||||
client = AsyncMock()
|
||||
received: list = []
|
||||
|
||||
async def _slow_receive() -> AsyncGenerator[str]:
|
||||
await asyncio.sleep(100) # never completes
|
||||
yield "never" # pragma: no cover — unreachable, yield makes this an async generator
|
||||
|
||||
client.receive_response = _slow_receive
|
||||
|
||||
with patch("backend.copilot.sdk.service._HEARTBEAT_INTERVAL", 0.01):
|
||||
count = 0
|
||||
async for msg in _iter_sdk_messages(client):
|
||||
received.append(msg)
|
||||
count += 1
|
||||
if count >= 3:
|
||||
break
|
||||
|
||||
assert all(m is None for m in received)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_propagates(self) -> None:
|
||||
client = AsyncMock()
|
||||
|
||||
async def _error_receive() -> AsyncGenerator[str]:
|
||||
raise RuntimeError("SDK crash")
|
||||
yield # pragma: no cover — unreachable, yield makes this an async generator
|
||||
|
||||
client.receive_response = _error_receive
|
||||
|
||||
with pytest.raises(RuntimeError, match="SDK crash"):
|
||||
async for _ in _iter_sdk_messages(client):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_cleanup_on_break(self) -> None:
|
||||
"""Pending task is cancelled when generator is closed."""
|
||||
client = AsyncMock()
|
||||
|
||||
async def _slow_receive() -> AsyncGenerator[str]:
|
||||
yield "first"
|
||||
await asyncio.sleep(100)
|
||||
yield "second"
|
||||
|
||||
client.receive_response = _slow_receive
|
||||
|
||||
gen = _iter_sdk_messages(client)
|
||||
first = await gen.__anext__()
|
||||
assert first == "first"
|
||||
await gen.aclose() # should cancel pending task cleanly
|
||||
@@ -288,90 +288,3 @@ class TestPromptSupplement:
|
||||
# Count how many times this tool appears as a bullet point
|
||||
count = docs.count(f"- **`{tool_name}`**")
|
||||
assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _cleanup_sdk_tool_results — orchestration + rate-limiting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanupSdkToolResults:
|
||||
"""Tests for _cleanup_sdk_tool_results orchestration and sweep rate-limiting."""
|
||||
|
||||
# All valid cwds must start with /tmp/copilot- (the _SDK_CWD_PREFIX).
|
||||
_CWD_PREFIX = "/tmp/copilot-"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_removes_cwd_directory(self):
|
||||
"""Cleanup removes the session working directory."""
|
||||
|
||||
from .service import _cleanup_sdk_tool_results
|
||||
|
||||
cwd = "/tmp/copilot-test-cleanup-remove"
|
||||
os.makedirs(cwd, exist_ok=True)
|
||||
|
||||
with patch("backend.copilot.sdk.service.cleanup_stale_project_dirs"):
|
||||
import backend.copilot.sdk.service as svc_mod
|
||||
|
||||
svc_mod._last_sweep_time = 0.0
|
||||
await _cleanup_sdk_tool_results(cwd)
|
||||
|
||||
assert not os.path.exists(cwd)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sweep_runs_when_interval_elapsed(self):
|
||||
"""cleanup_stale_project_dirs is called when 5-minute interval has elapsed."""
|
||||
|
||||
import backend.copilot.sdk.service as svc_mod
|
||||
|
||||
from .service import _cleanup_sdk_tool_results
|
||||
|
||||
cwd = "/tmp/copilot-test-sweep-elapsed"
|
||||
os.makedirs(cwd, exist_ok=True)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
|
||||
) as mock_sweep:
|
||||
# Set last sweep to a time far in the past
|
||||
svc_mod._last_sweep_time = 0.0
|
||||
await _cleanup_sdk_tool_results(cwd)
|
||||
|
||||
mock_sweep.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sweep_skipped_within_interval(self):
|
||||
"""cleanup_stale_project_dirs is NOT called when within 5-minute interval."""
|
||||
import time
|
||||
|
||||
import backend.copilot.sdk.service as svc_mod
|
||||
|
||||
from .service import _cleanup_sdk_tool_results
|
||||
|
||||
cwd = "/tmp/copilot-test-sweep-ratelimit"
|
||||
os.makedirs(cwd, exist_ok=True)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
|
||||
) as mock_sweep:
|
||||
# Set last sweep to now — interval not elapsed
|
||||
svc_mod._last_sweep_time = time.time()
|
||||
await _cleanup_sdk_tool_results(cwd)
|
||||
|
||||
mock_sweep.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_path_outside_prefix(self, tmp_path):
|
||||
"""Cleanup rejects a cwd that does not start with the expected prefix."""
|
||||
from .service import _cleanup_sdk_tool_results
|
||||
|
||||
evil_cwd = str(tmp_path / "evil-path")
|
||||
os.makedirs(evil_cwd, exist_ok=True)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
|
||||
) as mock_sweep:
|
||||
await _cleanup_sdk_tool_results(evil_cwd)
|
||||
|
||||
# Directory should NOT have been removed (rejected early)
|
||||
assert os.path.exists(evil_cwd)
|
||||
mock_sweep.assert_not_called()
|
||||
|
||||
@@ -146,7 +146,7 @@ def stash_pending_tool_output(tool_name: str, output: Any) -> None:
|
||||
event.set()
|
||||
|
||||
|
||||
async def wait_for_stash(timeout: float = 2.0) -> bool:
|
||||
async def wait_for_stash(timeout: float = 0.5) -> bool:
|
||||
"""Wait for a PostToolUse hook to stash tool output.
|
||||
|
||||
The SDK fires PostToolUse hooks asynchronously via ``start_soon()`` —
|
||||
@@ -155,12 +155,12 @@ async def wait_for_stash(timeout: float = 2.0) -> bool:
|
||||
by waiting on the ``_stash_event``, which is signaled by
|
||||
:func:`stash_pending_tool_output`.
|
||||
|
||||
Returns ``True`` if a stash signal was received, ``False`` on timeout.
|
||||
After the event fires, callers should ``await asyncio.sleep(0)`` to
|
||||
give any remaining concurrent hooks a chance to complete.
|
||||
|
||||
The 2.0 s default was chosen based on production metrics: the original
|
||||
0.5 s caused frequent timeouts under load (parallel tool calls, large
|
||||
outputs). 2.0 s gives a comfortable margin while still failing fast
|
||||
when the hook genuinely will not fire.
|
||||
Returns ``True`` if a stash signal was received, ``False`` on timeout.
|
||||
The timeout is a safety net — normally the stash happens within
|
||||
microseconds of yielding to the event loop.
|
||||
"""
|
||||
event = _stash_event.get(None)
|
||||
if event is None:
|
||||
@@ -234,7 +234,9 @@ def create_tool_handler(base_tool: BaseTool):
|
||||
try:
|
||||
return await _execute_tool_sync(base_tool, user_id, session, args)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
|
||||
logger.error(
|
||||
"Error executing tool %s: %s", base_tool.name, e, exc_info=True
|
||||
)
|
||||
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
|
||||
|
||||
return tool_handler
|
||||
@@ -285,7 +287,7 @@ async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
resolved = os.path.realpath(os.path.expanduser(file_path))
|
||||
try:
|
||||
with open(resolved, encoding="utf-8", errors="replace") as f:
|
||||
with open(resolved) as f:
|
||||
selected = list(itertools.islice(f, offset, offset + limit))
|
||||
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
|
||||
# don't delete here — the SDK may read in multiple chunks.
|
||||
|
||||
@@ -10,6 +10,9 @@ Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
|
||||
filesystem for self-hosted) — no DB column needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@@ -17,8 +20,12 @@ import shutil
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
from backend.util import json
|
||||
from backend.util.clients import get_openai_client
|
||||
from backend.util.prompt import CompressResult, compress_context
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage, get_workspace_storage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -99,7 +106,14 @@ def strip_progress_entries(content: str) -> str:
|
||||
continue
|
||||
parent = entry.get("parentUuid", "")
|
||||
original_parent = parent
|
||||
while parent in stripped_uuids:
|
||||
# seen_parents is local per-entry (not shared across iterations) so
|
||||
# it can only detect cycles within a single ancestry walk, not across
|
||||
# entries. This is intentional: each entry's parent chain is
|
||||
# independent, and reusing a global set would incorrectly short-circuit
|
||||
# valid re-use of the same UUID as a parent in different subtrees.
|
||||
seen_parents: set[str] = set()
|
||||
while parent in stripped_uuids and parent not in seen_parents:
|
||||
seen_parents.add(parent)
|
||||
parent = uuid_to_parent.get(parent, "")
|
||||
if parent != original_parent:
|
||||
entry["parentUuid"] = parent
|
||||
@@ -151,110 +165,44 @@ def _projects_base() -> str:
|
||||
return os.path.realpath(os.path.join(config_dir, "projects"))
|
||||
|
||||
|
||||
_STALE_PROJECT_DIR_SECONDS = 12 * 3600 # 12 hours — matches max session lifetime
|
||||
_MAX_PROJECT_DIRS_TO_SWEEP = 50 # limit per sweep to avoid long pauses
|
||||
def _cli_project_dir(sdk_cwd: str) -> str | None:
|
||||
"""Return the CLI's project directory for a given working directory.
|
||||
|
||||
|
||||
def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
|
||||
"""Remove CLI project directories older than ``_STALE_PROJECT_DIR_SECONDS``.
|
||||
|
||||
Each CoPilot SDK turn creates a unique ``~/.claude/projects/<encoded-cwd>/``
|
||||
directory. These are intentionally kept across turns so the model can read
|
||||
tool-result files via ``--resume``. However, after a session ends they
|
||||
become stale. This function sweeps old ones to prevent unbounded disk
|
||||
growth.
|
||||
|
||||
When *encoded_cwd* is provided the sweep is scoped to that single
|
||||
directory, making the operation safe in multi-tenant environments where
|
||||
multiple copilot sessions share the same host. Without it the function
|
||||
falls back to sweeping all directories matching the copilot naming pattern
|
||||
(``-tmp-copilot-``), which is only safe for single-tenant deployments.
|
||||
|
||||
Returns the number of directories removed.
|
||||
Returns ``None`` if the path would escape the projects base.
|
||||
"""
|
||||
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||
projects_base = _projects_base()
|
||||
if not os.path.isdir(projects_base):
|
||||
return 0
|
||||
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
|
||||
|
||||
now = time.time()
|
||||
removed = 0
|
||||
|
||||
# Scoped mode: only clean up the one directory for the current session.
|
||||
if encoded_cwd:
|
||||
target = Path(projects_base) / encoded_cwd
|
||||
if not target.is_dir():
|
||||
return 0
|
||||
# Guard: only sweep copilot-generated dirs.
|
||||
if "-tmp-copilot-" not in target.name:
|
||||
logger.warning(
|
||||
"[Transcript] Refusing to sweep non-copilot dir: %s", target.name
|
||||
)
|
||||
return 0
|
||||
try:
|
||||
# st_mtime is used as a proxy for session activity. Claude CLI writes
|
||||
# its JSONL transcript into this directory during each turn, so mtime
|
||||
# advances on every turn. A directory whose mtime is older than
|
||||
# _STALE_PROJECT_DIR_SECONDS has not had an active turn in that window
|
||||
# and is safe to remove (the session cannot --resume after cleanup).
|
||||
age = now - target.stat().st_mtime
|
||||
except OSError:
|
||||
return 0
|
||||
if age < _STALE_PROJECT_DIR_SECONDS:
|
||||
return 0
|
||||
try:
|
||||
shutil.rmtree(target, ignore_errors=True)
|
||||
removed = 1
|
||||
except OSError:
|
||||
pass
|
||||
if removed:
|
||||
logger.info(
|
||||
"[Transcript] Swept stale CLI project dir %s (age %ds > %ds)",
|
||||
target.name,
|
||||
int(age),
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
)
|
||||
return removed
|
||||
|
||||
# Unscoped fallback: sweep all copilot dirs across the projects base.
|
||||
# Only safe for single-tenant deployments; callers should prefer the
|
||||
# scoped variant by passing encoded_cwd.
|
||||
try:
|
||||
entries = Path(projects_base).iterdir()
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to list projects dir: %s", e)
|
||||
return 0
|
||||
|
||||
for entry in entries:
|
||||
if removed >= _MAX_PROJECT_DIRS_TO_SWEEP:
|
||||
break
|
||||
# Only sweep copilot-generated dirs (pattern: -tmp-copilot- or
|
||||
# -private-tmp-copilot-).
|
||||
if "-tmp-copilot-" not in entry.name:
|
||||
continue
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
try:
|
||||
# See the scoped-mode comment above: st_mtime advances on every turn,
|
||||
# so a stale mtime reliably indicates an inactive session.
|
||||
age = now - entry.stat().st_mtime
|
||||
except OSError:
|
||||
continue
|
||||
if age < _STALE_PROJECT_DIR_SECONDS:
|
||||
continue
|
||||
|
||||
try:
|
||||
shutil.rmtree(entry, ignore_errors=True)
|
||||
removed += 1
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
if removed:
|
||||
logger.info(
|
||||
"[Transcript] Swept %d stale CLI project dirs (older than %ds)",
|
||||
removed,
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
if not project_dir.startswith(projects_base + os.sep):
|
||||
logger.warning(
|
||||
"[Transcript] Project dir escaped projects base: %s", project_dir
|
||||
)
|
||||
return removed
|
||||
return None
|
||||
return project_dir
|
||||
|
||||
|
||||
def _safe_glob_jsonl(project_dir: str) -> list[Path]:
|
||||
"""Glob ``*.jsonl`` files, filtering out symlinks that escape the directory."""
|
||||
try:
|
||||
resolved_base = Path(project_dir).resolve()
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to resolve project dir: %s", e)
|
||||
return []
|
||||
|
||||
result: list[Path] = []
|
||||
for candidate in Path(project_dir).glob("*.jsonl"):
|
||||
try:
|
||||
resolved = candidate.resolve()
|
||||
if resolved.is_relative_to(resolved_base):
|
||||
result.append(resolved)
|
||||
except (OSError, RuntimeError) as e:
|
||||
logger.debug(
|
||||
"[Transcript] Skipping invalid CLI session candidate %s: %s",
|
||||
candidate,
|
||||
e,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def read_compacted_entries(transcript_path: str) -> list[dict] | None:
|
||||
@@ -321,6 +269,63 @@ def read_compacted_entries(transcript_path: str) -> list[dict] | None:
|
||||
return entries
|
||||
|
||||
|
||||
def read_cli_session_file(sdk_cwd: str) -> str | None:
|
||||
"""Read the CLI's own session file, which reflects any compaction.
|
||||
|
||||
The CLI writes its session transcript to
|
||||
``~/.claude/projects/<encoded_cwd>/<session_id>.jsonl``.
|
||||
Since each SDK turn uses a unique ``sdk_cwd``, there should be
|
||||
exactly one ``.jsonl`` file in that directory.
|
||||
|
||||
Returns the file content, or ``None`` if not found.
|
||||
"""
|
||||
project_dir = _cli_project_dir(sdk_cwd)
|
||||
if not project_dir or not os.path.isdir(project_dir):
|
||||
return None
|
||||
|
||||
jsonl_files = _safe_glob_jsonl(project_dir)
|
||||
if not jsonl_files:
|
||||
logger.debug("[Transcript] No CLI session file found in %s", project_dir)
|
||||
return None
|
||||
|
||||
# Pick the most recently modified file (should be only one per turn).
|
||||
try:
|
||||
session_file = max(jsonl_files, key=lambda p: p.stat().st_mtime)
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to inspect CLI session files: %s", e)
|
||||
return None
|
||||
|
||||
try:
|
||||
content = session_file.read_text()
|
||||
logger.info(
|
||||
"[Transcript] Read CLI session file: %s (%d bytes)",
|
||||
session_file,
|
||||
len(content),
|
||||
)
|
||||
return content
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to read CLI session file: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
|
||||
"""Remove the CLI's project directory for a specific working directory.
|
||||
|
||||
The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
|
||||
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
|
||||
safe to remove entirely after the transcript has been uploaded.
|
||||
"""
|
||||
project_dir = _cli_project_dir(sdk_cwd)
|
||||
if not project_dir:
|
||||
return
|
||||
|
||||
if os.path.isdir(project_dir):
|
||||
shutil.rmtree(project_dir, ignore_errors=True)
|
||||
logger.debug("[Transcript] Cleaned up CLI project dir: %s", project_dir)
|
||||
else:
|
||||
logger.debug("[Transcript] Project dir not found: %s", project_dir)
|
||||
|
||||
|
||||
def write_transcript_to_tempfile(
|
||||
transcript_content: str,
|
||||
session_id: str,
|
||||
@@ -336,7 +341,7 @@ def write_transcript_to_tempfile(
|
||||
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
|
||||
real_cwd = os.path.realpath(cwd)
|
||||
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
|
||||
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
|
||||
logger.warning("[Transcript] cwd outside sandbox: %s", cwd)
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -346,17 +351,17 @@ def write_transcript_to_tempfile(
|
||||
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
|
||||
)
|
||||
if not jsonl_path.startswith(real_cwd):
|
||||
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
|
||||
logger.warning("[Transcript] Path escaped cwd: %s", jsonl_path)
|
||||
return None
|
||||
|
||||
with open(jsonl_path, "w") as f:
|
||||
f.write(transcript_content)
|
||||
|
||||
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
|
||||
logger.info("[Transcript] Wrote resume file: %s", jsonl_path)
|
||||
return jsonl_path
|
||||
|
||||
except OSError as e:
|
||||
logger.warning(f"[Transcript] Failed to write resume file: {e}")
|
||||
logger.warning("[Transcript] Failed to write resume file: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
@@ -417,8 +422,6 @@ def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, s
|
||||
|
||||
def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
|
||||
"""Build a full storage path from (workspace_id, file_id, filename) parts."""
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage
|
||||
|
||||
wid, fid, fname = parts
|
||||
if isinstance(backend, GCSWorkspaceStorage):
|
||||
blob = f"workspaces/{wid}/{fid}/{fname}"
|
||||
@@ -457,17 +460,15 @@ async def upload_transcript(
|
||||
content: Complete JSONL transcript (from TranscriptBuilder).
|
||||
message_count: ``len(session.messages)`` at upload time.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
# Strip metadata entries (progress, file-history-snapshot, etc.)
|
||||
# Note: SDK-built transcripts shouldn't have these, but strip for safety
|
||||
stripped = strip_progress_entries(content)
|
||||
if not validate_transcript(stripped):
|
||||
# Log entry types for debugging — helps identify why validation failed
|
||||
entry_types: list[str] = []
|
||||
for line in stripped.strip().split("\n"):
|
||||
entry = json.loads(line, fallback={"type": "INVALID_JSON"})
|
||||
entry_types.append(entry.get("type", "?"))
|
||||
entry_types = [
|
||||
json.loads(line, fallback={"type": "INVALID_JSON"}).get("type", "?")
|
||||
for line in stripped.strip().split("\n")
|
||||
]
|
||||
logger.warning(
|
||||
"%s Skipping upload — stripped content not valid "
|
||||
"(types=%s, stripped_len=%d, raw_len=%d)",
|
||||
@@ -503,11 +504,14 @@ async def upload_transcript(
|
||||
content=json.dumps(meta).encode("utf-8"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
|
||||
logger.warning("%s Failed to write metadata: %s", log_prefix, e)
|
||||
|
||||
logger.info(
|
||||
f"{log_prefix} Uploaded {len(encoded)}B "
|
||||
f"(stripped from {len(content)}B, msg_count={message_count})"
|
||||
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
|
||||
log_prefix,
|
||||
len(encoded),
|
||||
len(content),
|
||||
message_count,
|
||||
)
|
||||
|
||||
|
||||
@@ -521,8 +525,6 @@ async def download_transcript(
|
||||
Returns a ``TranscriptDownload`` with the JSONL content and the
|
||||
``message_count`` watermark from the upload, or ``None`` if not found.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
|
||||
@@ -530,10 +532,10 @@ async def download_transcript(
|
||||
data = await storage.retrieve(path)
|
||||
content = data.decode("utf-8")
|
||||
except FileNotFoundError:
|
||||
logger.debug(f"{log_prefix} No transcript in storage")
|
||||
logger.debug("%s No transcript in storage", log_prefix)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
|
||||
logger.warning("%s Failed to download transcript: %s", log_prefix, e)
|
||||
return None
|
||||
|
||||
# Try to load metadata (best-effort — old transcripts won't have it)
|
||||
@@ -545,10 +547,14 @@ async def download_transcript(
|
||||
meta = json.loads(meta_data.decode("utf-8"), fallback={})
|
||||
message_count = meta.get("message_count", 0)
|
||||
uploaded_at = meta.get("uploaded_at", 0.0)
|
||||
except (FileNotFoundError, Exception):
|
||||
except FileNotFoundError:
|
||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||
except Exception as e:
|
||||
logger.debug("%s Failed to load transcript metadata: %s", log_prefix, e)
|
||||
|
||||
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
|
||||
logger.info(
|
||||
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
|
||||
)
|
||||
return TranscriptDownload(
|
||||
content=content,
|
||||
message_count=message_count,
|
||||
@@ -562,8 +568,6 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
Removes both the ``.jsonl`` transcript and the companion ``.meta.json``
|
||||
so stale ``message_count`` watermarks cannot corrupt gap-fill logic.
|
||||
"""
|
||||
from backend.util.workspace_storage import get_workspace_storage
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
|
||||
@@ -580,3 +584,280 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
logger.info("[Transcript] Deleted metadata for session %s", session_id)
|
||||
except Exception as e:
|
||||
logger.warning("[Transcript] Failed to delete metadata: %s", e)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transcript compaction — LLM summarization for prompt-too-long recovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# JSONL protocol values used in transcript serialization.
|
||||
STOP_REASON_END_TURN = "end_turn"
|
||||
COMPACT_MSG_ID_PREFIX = "msg_compact_"
|
||||
ENTRY_TYPE_MESSAGE = "message"
|
||||
|
||||
|
||||
def _flatten_assistant_content(blocks: list) -> str:
|
||||
"""Flatten assistant content blocks into a single plain-text string.
|
||||
|
||||
Structured ``tool_use`` blocks are converted to ``[tool_use: name]``
|
||||
placeholders. This is intentional: ``compress_context`` requires plain
|
||||
text for token counting and LLM summarization. The structural loss is
|
||||
acceptable because compaction only runs when the original transcript was
|
||||
already too large for the model — a summarized plain-text version is
|
||||
better than no context at all.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict):
|
||||
btype = block.get("type", "")
|
||||
if btype == "text":
|
||||
parts.append(block.get("text", ""))
|
||||
elif btype == "tool_use":
|
||||
parts.append(f"[tool_use: {block.get('name', '?')}]")
|
||||
else:
|
||||
# Preserve non-text blocks (e.g. image) as placeholders.
|
||||
# Use __prefix__ to distinguish from literal user text.
|
||||
parts.append(f"[__{btype}__]")
|
||||
elif isinstance(block, str):
|
||||
parts.append(block)
|
||||
return "\n".join(parts) if parts else ""
|
||||
|
||||
|
||||
def _flatten_tool_result_content(blocks: list) -> str:
|
||||
"""Flatten tool_result and other content blocks into plain text.
|
||||
|
||||
Handles nested tool_result structures, text blocks, and raw strings.
|
||||
Uses ``json.dumps`` as fallback for dict blocks without a ``text`` key
|
||||
or where ``text`` is ``None``.
|
||||
|
||||
Like ``_flatten_assistant_content``, structured blocks (images, nested
|
||||
tool results) are reduced to text representations for compression.
|
||||
"""
|
||||
str_parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_result":
|
||||
inner = block.get("content") or ""
|
||||
if isinstance(inner, list):
|
||||
for sub in inner:
|
||||
if isinstance(sub, dict):
|
||||
sub_type = sub.get("type")
|
||||
if sub_type in ("image", "document"):
|
||||
# Avoid serializing base64 binary data into
|
||||
# the compaction input — use a placeholder.
|
||||
str_parts.append(f"[__{sub_type}__]")
|
||||
elif sub_type == "text" or sub.get("text") is not None:
|
||||
str_parts.append(str(sub.get("text", "")))
|
||||
else:
|
||||
str_parts.append(json.dumps(sub))
|
||||
else:
|
||||
str_parts.append(str(sub))
|
||||
else:
|
||||
str_parts.append(str(inner))
|
||||
elif isinstance(block, dict) and block.get("type") == "text":
|
||||
str_parts.append(str(block.get("text", "")))
|
||||
elif isinstance(block, dict):
|
||||
# Preserve non-text/non-tool_result blocks (e.g. image) as placeholders.
|
||||
# Use __prefix__ to distinguish from literal user text.
|
||||
btype = block.get("type", "unknown")
|
||||
str_parts.append(f"[__{btype}__]")
|
||||
elif isinstance(block, str):
|
||||
str_parts.append(block)
|
||||
return "\n".join(str_parts) if str_parts else ""
|
||||
|
||||
|
||||
def _transcript_to_messages(content: str) -> list[dict]:
|
||||
"""Convert JSONL transcript entries to plain message dicts for compression.
|
||||
|
||||
Parses each line of the JSONL *content*, skips strippable metadata entries
|
||||
(progress, file-history-snapshot, etc.), and extracts the ``role`` and
|
||||
flattened ``content`` from the ``message`` field of each remaining entry.
|
||||
|
||||
Structured content blocks (``tool_use``, ``tool_result``, images) are
|
||||
flattened to plain text via ``_flatten_assistant_content`` and
|
||||
``_flatten_tool_result_content`` so that ``compress_context`` can
|
||||
perform token counting and LLM summarization on uniform strings.
|
||||
|
||||
Returns:
|
||||
A list of ``{"role": str, "content": str}`` dicts suitable for
|
||||
``compress_context``.
|
||||
"""
|
||||
messages: list[dict] = []
|
||||
for line in content.strip().split("\n"):
|
||||
if not line.strip():
|
||||
continue
|
||||
entry = json.loads(line, fallback=None)
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
|
||||
"isCompactSummary"
|
||||
):
|
||||
continue
|
||||
msg = entry.get("message", {})
|
||||
role = msg.get("role", "")
|
||||
if not role:
|
||||
continue
|
||||
msg_dict: dict = {"role": role}
|
||||
raw_content = msg.get("content")
|
||||
if role == "assistant" and isinstance(raw_content, list):
|
||||
msg_dict["content"] = _flatten_assistant_content(raw_content)
|
||||
elif isinstance(raw_content, list):
|
||||
msg_dict["content"] = _flatten_tool_result_content(raw_content)
|
||||
else:
|
||||
msg_dict["content"] = raw_content or ""
|
||||
messages.append(msg_dict)
|
||||
return messages
|
||||
|
||||
|
||||
def _messages_to_transcript(messages: list[dict]) -> str:
|
||||
"""Convert compressed message dicts back to JSONL transcript format.
|
||||
|
||||
Rebuilds a minimal JSONL transcript from the ``{"role", "content"}``
|
||||
dicts returned by ``compress_context``. Each message becomes one JSONL
|
||||
line with a fresh ``uuid`` / ``parentUuid`` chain so the CLI's
|
||||
``--resume`` flag can reconstruct a valid conversation tree.
|
||||
|
||||
Assistant messages are wrapped in the full ``message`` envelope
|
||||
(``id``, ``model``, ``stop_reason``, structured ``content`` blocks)
|
||||
that the CLI expects. User messages use the simpler ``{role, content}``
|
||||
form.
|
||||
|
||||
Returns:
|
||||
A newline-terminated JSONL string, or an empty string if *messages*
|
||||
is empty.
|
||||
"""
|
||||
lines: list[str] = []
|
||||
last_uuid: str = "" # root entry uses empty string, not null
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
entry_type = "assistant" if role == "assistant" else "user"
|
||||
uid = str(uuid4())
|
||||
content = msg.get("content", "")
|
||||
if role == "assistant":
|
||||
message: dict = {
|
||||
"role": "assistant",
|
||||
"model": "",
|
||||
"id": f"{COMPACT_MSG_ID_PREFIX}{uuid4().hex[:24]}",
|
||||
"type": ENTRY_TYPE_MESSAGE,
|
||||
"content": [{"type": "text", "text": content}] if content else [],
|
||||
"stop_reason": STOP_REASON_END_TURN,
|
||||
"stop_sequence": None,
|
||||
}
|
||||
else:
|
||||
message = {"role": role, "content": content}
|
||||
entry = {
|
||||
"type": entry_type,
|
||||
"uuid": uid,
|
||||
"parentUuid": last_uuid,
|
||||
"message": message,
|
||||
}
|
||||
lines.append(json.dumps(entry, separators=(",", ":")))
|
||||
last_uuid = uid
|
||||
return "\n".join(lines) + "\n" if lines else ""
|
||||
|
||||
|
||||
_COMPACTION_TIMEOUT_SECONDS = 60
|
||||
_TRUNCATION_TIMEOUT_SECONDS = 30
|
||||
|
||||
|
||||
async def _run_compression(
|
||||
messages: list[dict],
|
||||
model: str,
|
||||
log_prefix: str,
|
||||
) -> CompressResult:
|
||||
"""Run LLM-based compression with truncation fallback.
|
||||
|
||||
Uses the shared OpenAI client from ``get_openai_client()``.
|
||||
If no client is configured or the LLM call fails, falls back to
|
||||
truncation-based compression which drops older messages without
|
||||
summarization.
|
||||
|
||||
A 60-second timeout prevents a hung LLM call from blocking the
|
||||
retry path indefinitely. The truncation fallback also has a
|
||||
30-second timeout to guard against slow tokenization on very large
|
||||
transcripts.
|
||||
"""
|
||||
client = get_openai_client()
|
||||
if client is None:
|
||||
logger.warning("%s No OpenAI client configured, using truncation", log_prefix)
|
||||
return await asyncio.wait_for(
|
||||
compress_context(messages=messages, model=model, client=None),
|
||||
timeout=_TRUNCATION_TIMEOUT_SECONDS,
|
||||
)
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
compress_context(messages=messages, model=model, client=client),
|
||||
timeout=_COMPACTION_TIMEOUT_SECONDS,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
|
||||
return await asyncio.wait_for(
|
||||
compress_context(messages=messages, model=model, client=None),
|
||||
timeout=_TRUNCATION_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
async def compact_transcript(
|
||||
content: str,
|
||||
*,
|
||||
model: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> str | None:
|
||||
"""Compact an oversized JSONL transcript using LLM summarization.
|
||||
|
||||
Converts transcript entries to plain messages, runs ``compress_context``
|
||||
(the same compressor used for pre-query history), and rebuilds JSONL.
|
||||
|
||||
Structured content (``tool_use`` blocks, ``tool_result`` nesting, images)
|
||||
is flattened to plain text for compression. This matches the fidelity of
|
||||
the Plan C (DB compression) fallback path, where
|
||||
``_format_conversation_context`` similarly renders tool calls as
|
||||
``You called tool: name(args)`` and results as ``Tool result: ...``.
|
||||
Neither path preserves structured API content blocks — the compacted
|
||||
context serves as text history for the LLM, which creates proper
|
||||
structured tool calls going forward.
|
||||
|
||||
Images are per-turn attachments loaded from workspace storage by file ID
|
||||
(via ``_prepare_file_attachments``), not part of the conversation history.
|
||||
They are re-attached each turn and are unaffected by compaction.
|
||||
|
||||
Returns the compacted JSONL string, or ``None`` on failure.
|
||||
|
||||
See also:
|
||||
``_compress_messages`` in ``service.py`` — compresses ``ChatMessage``
|
||||
lists for pre-query DB history. Both share ``compress_context()``
|
||||
but operate on different input formats (JSONL transcript entries
|
||||
here vs. ChatMessage dicts there).
|
||||
"""
|
||||
messages = _transcript_to_messages(content)
|
||||
if len(messages) < 2:
|
||||
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
|
||||
return None
|
||||
try:
|
||||
result = await _run_compression(messages, model, log_prefix)
|
||||
if not result.was_compacted:
|
||||
# Compressor says it's within budget, but the SDK rejected it.
|
||||
# Return None so the caller falls through to DB fallback.
|
||||
logger.warning(
|
||||
"%s Compressor reports within budget but SDK rejected — "
|
||||
"signalling failure",
|
||||
log_prefix,
|
||||
)
|
||||
return None
|
||||
logger.info(
|
||||
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
|
||||
log_prefix,
|
||||
result.original_token_count,
|
||||
result.token_count,
|
||||
result.messages_summarized,
|
||||
result.messages_dropped,
|
||||
)
|
||||
compacted = _messages_to_transcript(result.messages)
|
||||
if not validate_transcript(compacted):
|
||||
logger.warning("%s Compacted transcript failed validation", log_prefix)
|
||||
return None
|
||||
return compacted
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -68,7 +68,7 @@ class TranscriptBuilder:
|
||||
type=entry_type,
|
||||
uuid=data.get("uuid") or str(uuid4()),
|
||||
parentUuid=data.get("parentUuid"),
|
||||
isCompactSummary=data.get("isCompactSummary") or None,
|
||||
isCompactSummary=data.get("isCompactSummary"),
|
||||
message=data.get("message", {}),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Unit tests for JSONL transcript management utilities."""
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -9,7 +9,9 @@ from backend.util import json
|
||||
|
||||
from .transcript import (
|
||||
STRIPPABLE_TYPES,
|
||||
_cli_project_dir,
|
||||
delete_transcript,
|
||||
read_cli_session_file,
|
||||
read_compacted_entries,
|
||||
strip_progress_entries,
|
||||
validate_transcript,
|
||||
@@ -290,6 +292,85 @@ class TestStripProgressEntries:
|
||||
assert asst_entry["parentUuid"] == "u1" # reparented
|
||||
|
||||
|
||||
# --- read_cli_session_file ---
|
||||
|
||||
|
||||
class TestReadCliSessionFile:
|
||||
def test_no_matching_files_returns_none(self, tmp_path, monkeypatch):
|
||||
"""read_cli_session_file returns None when no .jsonl files exist."""
|
||||
# Create a project dir with no jsonl files
|
||||
project_dir = tmp_path / "projects" / "encoded-cwd"
|
||||
project_dir.mkdir(parents=True)
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._cli_project_dir",
|
||||
lambda sdk_cwd: str(project_dir),
|
||||
)
|
||||
assert read_cli_session_file("/fake/cwd") is None
|
||||
|
||||
def test_one_jsonl_file_returns_content(self, tmp_path, monkeypatch):
|
||||
"""read_cli_session_file returns the content of a single .jsonl file."""
|
||||
project_dir = tmp_path / "projects" / "encoded-cwd"
|
||||
project_dir.mkdir(parents=True)
|
||||
jsonl_file = project_dir / "session.jsonl"
|
||||
jsonl_file.write_text("line1\nline2\n")
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._cli_project_dir",
|
||||
lambda sdk_cwd: str(project_dir),
|
||||
)
|
||||
result = read_cli_session_file("/fake/cwd")
|
||||
assert result == "line1\nline2\n"
|
||||
|
||||
def test_symlink_escaping_project_dir_is_skipped(self, tmp_path, monkeypatch):
|
||||
"""read_cli_session_file skips symlinks that escape the project dir."""
|
||||
project_dir = tmp_path / "projects" / "encoded-cwd"
|
||||
project_dir.mkdir(parents=True)
|
||||
|
||||
# Create a file outside the project dir
|
||||
outside = tmp_path / "outside"
|
||||
outside.mkdir()
|
||||
outside_file = outside / "evil.jsonl"
|
||||
outside_file.write_text("should not be read\n")
|
||||
|
||||
# Symlink from inside project_dir to outside file
|
||||
symlink = project_dir / "evil.jsonl"
|
||||
symlink.symlink_to(outside_file)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._cli_project_dir",
|
||||
lambda sdk_cwd: str(project_dir),
|
||||
)
|
||||
# The symlink target resolves outside project_dir, so it should be skipped
|
||||
result = read_cli_session_file("/fake/cwd")
|
||||
assert result is None
|
||||
|
||||
|
||||
# --- _cli_project_dir ---
|
||||
|
||||
|
||||
class TestCliProjectDir:
|
||||
def test_returns_none_for_path_traversal(self, tmp_path, monkeypatch):
|
||||
"""_cli_project_dir returns None when the project dir symlink escapes projects base."""
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir()
|
||||
projects_dir = config_dir / "projects"
|
||||
projects_dir.mkdir()
|
||||
|
||||
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
|
||||
|
||||
# Create a symlink inside projects/ that points outside of it.
|
||||
# _cli_project_dir encodes the cwd as all-alnum-hyphens, so use a
|
||||
# cwd whose encoded form matches the symlink name we create.
|
||||
evil_target = tmp_path / "escaped"
|
||||
evil_target.mkdir()
|
||||
|
||||
# The encoded form of "/evil/cwd" is "-evil-cwd"
|
||||
symlink_path = projects_dir / "-evil-cwd"
|
||||
symlink_path.symlink_to(evil_target)
|
||||
|
||||
result = _cli_project_dir("/evil/cwd")
|
||||
assert result is None
|
||||
|
||||
|
||||
# --- delete_transcript ---
|
||||
|
||||
|
||||
@@ -301,7 +382,7 @@ class TestDeleteTranscript:
|
||||
mock_storage.delete = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"backend.util.workspace_storage.get_workspace_storage",
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -321,7 +402,7 @@ class TestDeleteTranscript:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.util.workspace_storage.get_workspace_storage",
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -339,7 +420,7 @@ class TestDeleteTranscript:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.util.workspace_storage.get_workspace_storage",
|
||||
"backend.copilot.sdk.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
@@ -819,206 +900,131 @@ class TestCompactionFlowIntegration:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cleanup_stale_project_dirs
|
||||
# _run_compression (direct tests for the 3 code paths)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanupStaleProjectDirs:
|
||||
"""Tests for cleanup_stale_project_dirs (disk leak prevention)."""
|
||||
class TestRunCompression:
|
||||
"""Direct tests for ``_run_compression`` covering all 3 code paths.
|
||||
|
||||
def test_removes_old_copilot_dirs(self, tmp_path, monkeypatch):
|
||||
"""Directories matching copilot pattern older than threshold are removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
Paths:
|
||||
(a) No OpenAI client configured → truncation fallback immediately.
|
||||
(b) LLM success → returns LLM-compressed result.
|
||||
(c) LLM call raises → truncation fallback.
|
||||
"""
|
||||
|
||||
def _make_compress_result(self, was_compacted: bool, msgs=None):
|
||||
"""Build a minimal CompressResult-like object."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
return SimpleNamespace(
|
||||
was_compacted=was_compacted,
|
||||
messages=msgs or [{"role": "user", "content": "summary"}],
|
||||
original_token_count=500,
|
||||
token_count=100 if was_compacted else 500,
|
||||
messages_summarized=2 if was_compacted else 0,
|
||||
messages_dropped=0,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_client_uses_truncation(self):
|
||||
"""Path (a): ``get_openai_client()`` returns None → truncation only."""
|
||||
from .transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated"}]
|
||||
)
|
||||
|
||||
# Create a stale dir
|
||||
stale = projects_dir / "-tmp-copilot-old-session"
|
||||
stale.mkdir()
|
||||
# Set mtime to past the threshold
|
||||
import time
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=truncation_result,
|
||||
) as mock_compress,
|
||||
):
|
||||
result = await _run_compression(
|
||||
[{"role": "user", "content": "hello"}],
|
||||
model="test-model",
|
||||
log_prefix="[test]",
|
||||
)
|
||||
|
||||
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
|
||||
os.utime(stale, (old_time, old_time))
|
||||
|
||||
# Create a fresh dir
|
||||
fresh = projects_dir / "-tmp-copilot-new-session"
|
||||
fresh.mkdir()
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 1
|
||||
assert not stale.exists()
|
||||
assert fresh.exists()
|
||||
|
||||
def test_ignores_non_copilot_dirs(self, tmp_path, monkeypatch):
|
||||
"""Directories not matching copilot pattern are left alone."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
# compress_context called with client=None (truncation mode)
|
||||
call_kwargs = mock_compress.call_args
|
||||
assert (
|
||||
call_kwargs.kwargs.get("client") is None
|
||||
or (call_kwargs.args and call_kwargs.args[2] is None)
|
||||
or mock_compress.call_args[1].get("client") is None
|
||||
)
|
||||
assert result is truncation_result
|
||||
|
||||
# Non-copilot dir that's old
|
||||
import time
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_success_returns_llm_result(self):
|
||||
"""Path (b): ``get_openai_client()`` returns a client → LLM compresses."""
|
||||
from .transcript import _run_compression
|
||||
|
||||
other = projects_dir / "some-other-project"
|
||||
other.mkdir()
|
||||
old_time = time.time() - 999999
|
||||
os.utime(other, (old_time, old_time))
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 0
|
||||
assert other.exists()
|
||||
|
||||
def test_ttl_boundary_not_removed(self, tmp_path, monkeypatch):
|
||||
"""A directory exactly at the TTL boundary should NOT be removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
llm_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "LLM summary"}]
|
||||
)
|
||||
mock_client = MagicMock()
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
new_callable=AsyncMock,
|
||||
return_value=llm_result,
|
||||
) as mock_compress,
|
||||
):
|
||||
result = await _run_compression(
|
||||
[{"role": "user", "content": "long conversation"}],
|
||||
model="test-model",
|
||||
log_prefix="[test]",
|
||||
)
|
||||
|
||||
# compress_context called with the real client
|
||||
assert mock_compress.called
|
||||
assert result is llm_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_failure_falls_back_to_truncation(self):
|
||||
"""Path (c): LLM call raises → truncation fallback used instead."""
|
||||
from .transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated fallback"}]
|
||||
)
|
||||
mock_client = MagicMock()
|
||||
call_count = [0]
|
||||
|
||||
import time
|
||||
async def _compress_side_effect(**kwargs):
|
||||
call_count[0] += 1
|
||||
if kwargs.get("client") is not None:
|
||||
raise RuntimeError("LLM timeout")
|
||||
return truncation_result
|
||||
|
||||
# Dir that's exactly at the TTL (age == threshold, not >) — should survive
|
||||
boundary = projects_dir / "-tmp-copilot-boundary"
|
||||
boundary.mkdir()
|
||||
boundary_time = time.time() - _STALE_PROJECT_DIR_SECONDS + 1
|
||||
os.utime(boundary, (boundary_time, boundary_time))
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.get_openai_client",
|
||||
return_value=mock_client,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.sdk.transcript.compress_context",
|
||||
side_effect=_compress_side_effect,
|
||||
),
|
||||
):
|
||||
result = await _run_compression(
|
||||
[{"role": "user", "content": "long conversation"}],
|
||||
model="test-model",
|
||||
log_prefix="[test]",
|
||||
)
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 0
|
||||
assert boundary.exists()
|
||||
|
||||
def test_skips_non_directory_entries(self, tmp_path, monkeypatch):
|
||||
"""Regular files matching the copilot pattern are not removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
import time
|
||||
|
||||
# Create a regular FILE (not a dir) with the copilot pattern name
|
||||
stale_file = projects_dir / "-tmp-copilot-stale-file"
|
||||
stale_file.write_text("not a dir")
|
||||
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
|
||||
os.utime(stale_file, (old_time, old_time))
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 0
|
||||
assert stale_file.exists()
|
||||
|
||||
def test_missing_base_dir_returns_zero(self, tmp_path, monkeypatch):
|
||||
"""If the projects base directory doesn't exist, return 0 gracefully."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
|
||||
nonexistent = str(tmp_path / "does-not-exist" / "projects")
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: nonexistent,
|
||||
)
|
||||
|
||||
removed = cleanup_stale_project_dirs()
|
||||
assert removed == 0
|
||||
|
||||
def test_scoped_removes_only_target_dir(self, tmp_path, monkeypatch):
|
||||
"""When encoded_cwd is supplied only that directory is swept."""
|
||||
import time
|
||||
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
|
||||
|
||||
# Two stale copilot dirs
|
||||
target = projects_dir / "-tmp-copilot-session-abc"
|
||||
target.mkdir()
|
||||
os.utime(target, (old_time, old_time))
|
||||
|
||||
other = projects_dir / "-tmp-copilot-session-xyz"
|
||||
other.mkdir()
|
||||
os.utime(other, (old_time, old_time))
|
||||
|
||||
# Only the target dir should be removed
|
||||
removed = cleanup_stale_project_dirs(encoded_cwd="-tmp-copilot-session-abc")
|
||||
assert removed == 1
|
||||
assert not target.exists()
|
||||
assert other.exists() # untouched — not the current session
|
||||
|
||||
def test_scoped_fresh_dir_not_removed(self, tmp_path, monkeypatch):
|
||||
"""Scoped sweep leaves a fresh directory alone."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
fresh = projects_dir / "-tmp-copilot-session-new"
|
||||
fresh.mkdir()
|
||||
# mtime is now — well within TTL
|
||||
|
||||
removed = cleanup_stale_project_dirs(encoded_cwd="-tmp-copilot-session-new")
|
||||
assert removed == 0
|
||||
assert fresh.exists()
|
||||
|
||||
def test_scoped_non_copilot_dir_not_removed(self, tmp_path, monkeypatch):
|
||||
"""Scoped sweep refuses to remove a non-copilot directory."""
|
||||
import time
|
||||
|
||||
from backend.copilot.sdk.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.transcript._projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
|
||||
non_copilot = projects_dir / "some-other-project"
|
||||
non_copilot.mkdir()
|
||||
os.utime(non_copilot, (old_time, old_time))
|
||||
|
||||
removed = cleanup_stale_project_dirs(encoded_cwd="some-other-project")
|
||||
assert removed == 0
|
||||
assert non_copilot.exists()
|
||||
# compress_context called twice: once for LLM (raises), once for truncation
|
||||
assert call_count[0] == 2
|
||||
assert result is truncation_result
|
||||
|
||||
@@ -41,8 +41,7 @@ import contextlib
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Literal
|
||||
|
||||
from e2b import AsyncSandbox
|
||||
from e2b.sandbox.sandbox_api import SandboxLifecycle
|
||||
from e2b import AsyncSandbox, SandboxLifecycle
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -11,9 +10,7 @@ from pydantic import BaseModel
|
||||
from backend.copilot.context import (
|
||||
E2B_WORKDIR,
|
||||
get_current_sandbox,
|
||||
get_sdk_cwd,
|
||||
get_workspace_manager,
|
||||
is_allowed_local_path,
|
||||
resolve_sandbox_path,
|
||||
)
|
||||
from backend.copilot.model import ChatSession
|
||||
@@ -27,10 +24,6 @@ from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Sentinel file_id used when a tool-result file is read directly from the local
|
||||
# host filesystem (rather than from workspace storage).
|
||||
_LOCAL_TOOL_RESULT_FILE_ID = "local"
|
||||
|
||||
|
||||
async def _resolve_write_content(
|
||||
content_text: str | None,
|
||||
@@ -282,93 +275,6 @@ class WorkspaceFileContentResponse(ToolResponseBase):
|
||||
content_base64: str
|
||||
|
||||
|
||||
_MAX_LOCAL_TOOL_RESULT_BYTES = 10 * 1024 * 1024 # 10 MB
|
||||
|
||||
|
||||
def _read_local_tool_result(
|
||||
path: str,
|
||||
char_offset: int,
|
||||
char_length: Optional[int],
|
||||
session_id: str,
|
||||
sdk_cwd: str | None = None,
|
||||
) -> ToolResponseBase:
|
||||
"""Read an SDK tool-result file from local disk.
|
||||
|
||||
This is a fallback for when the model mistakenly calls
|
||||
``read_workspace_file`` with an SDK tool-result path that only exists on
|
||||
the host filesystem, not in cloud workspace storage.
|
||||
|
||||
Defence-in-depth: validates *path* via :func:`is_allowed_local_path`
|
||||
regardless of what the caller has already checked.
|
||||
"""
|
||||
# TOCTOU: path validated then opened separately. Acceptable because
|
||||
# the tool-results directory is server-controlled, not user-writable.
|
||||
expanded = os.path.realpath(os.path.expanduser(path))
|
||||
# Defence-in-depth: re-check with resolved path (caller checked raw path).
|
||||
if not is_allowed_local_path(expanded, sdk_cwd or get_sdk_cwd()):
|
||||
return ErrorResponse(
|
||||
message=f"Path not allowed: {os.path.basename(path)}", session_id=session_id
|
||||
)
|
||||
try:
|
||||
# The 10 MB cap (_MAX_LOCAL_TOOL_RESULT_BYTES) bounds memory usage.
|
||||
# Pre-read size check prevents loading files far above the cap;
|
||||
# the remaining TOCTOU gap is acceptable for server-controlled paths.
|
||||
file_size = os.path.getsize(expanded)
|
||||
if file_size > _MAX_LOCAL_TOOL_RESULT_BYTES:
|
||||
return ErrorResponse(
|
||||
message=(f"File too large: {os.path.basename(path)}"),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Detect binary files: try strict UTF-8 first, fall back to
|
||||
# base64-encoding the raw bytes for binary content.
|
||||
with open(expanded, "rb") as fh:
|
||||
raw = fh.read()
|
||||
try:
|
||||
text_content = raw.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
# Binary file — return raw base64, ignore char_offset/char_length
|
||||
return WorkspaceFileContentResponse(
|
||||
file_id=_LOCAL_TOOL_RESULT_FILE_ID,
|
||||
name=os.path.basename(path),
|
||||
path=path,
|
||||
mime_type=mimetypes.guess_type(path)[0] or "application/octet-stream",
|
||||
content_base64=base64.b64encode(raw).decode("ascii"),
|
||||
message=(
|
||||
f"Read {file_size:,} bytes (binary) from local tool-result "
|
||||
f"{os.path.basename(path)}"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
end = (
|
||||
char_offset + char_length if char_length is not None else len(text_content)
|
||||
)
|
||||
slice_text = text_content[char_offset:end]
|
||||
except FileNotFoundError:
|
||||
return ErrorResponse(
|
||||
message=f"File not found: {os.path.basename(path)}", session_id=session_id
|
||||
)
|
||||
except Exception as exc:
|
||||
return ErrorResponse(
|
||||
message=f"Error reading file: {type(exc).__name__}", session_id=session_id
|
||||
)
|
||||
|
||||
return WorkspaceFileContentResponse(
|
||||
file_id=_LOCAL_TOOL_RESULT_FILE_ID,
|
||||
name=os.path.basename(path),
|
||||
path=path,
|
||||
mime_type=mimetypes.guess_type(path)[0] or "text/plain",
|
||||
content_base64=base64.b64encode(slice_text.encode("utf-8")).decode("ascii"),
|
||||
message=(
|
||||
f"Read chars {char_offset}\u2013{char_offset + len(slice_text)} "
|
||||
f"of {len(text_content):,} chars from local tool-result "
|
||||
f"{os.path.basename(path)}"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class WorkspaceFileMetadataResponse(ToolResponseBase):
|
||||
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
|
||||
|
||||
@@ -627,14 +533,6 @@ class ReadWorkspaceFileTool(BaseTool):
|
||||
manager = await get_workspace_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
# Fallback: if the path is an SDK tool-result on local disk,
|
||||
# read it directly instead of failing. The model sometimes
|
||||
# calls read_workspace_file for these paths by mistake.
|
||||
sdk_cwd = get_sdk_cwd()
|
||||
if path and is_allowed_local_path(path, sdk_cwd):
|
||||
return _read_local_tool_result(
|
||||
path, char_offset, char_length, session_id, sdk_cwd=sdk_cwd
|
||||
)
|
||||
return resolved
|
||||
target_file_id, file_info = resolved
|
||||
|
||||
|
||||
@@ -2,25 +2,18 @@
|
||||
|
||||
import base64
|
||||
import os
|
||||
import shutil
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.context import SDK_PROJECTS_DIR, _current_project_dir
|
||||
from backend.copilot.tools._test_data import make_session, setup_test_data
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
from backend.copilot.tools.workspace_files import (
|
||||
_MAX_LOCAL_TOOL_RESULT_BYTES,
|
||||
DeleteWorkspaceFileTool,
|
||||
ListWorkspaceFilesTool,
|
||||
ReadWorkspaceFileTool,
|
||||
WorkspaceDeleteResponse,
|
||||
WorkspaceFileContentResponse,
|
||||
WorkspaceFileListResponse,
|
||||
WorkspaceWriteResponse,
|
||||
WriteWorkspaceFileTool,
|
||||
_read_local_tool_result,
|
||||
_resolve_write_content,
|
||||
_validate_ephemeral_path,
|
||||
)
|
||||
@@ -332,294 +325,3 @@ async def test_write_workspace_file_source_path(setup_test_data):
|
||||
await delete_tool._execute(
|
||||
user_id=user.id, session=session, file_id=write_resp.file_id
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_local_tool_result — local disk fallback for SDK tool-result files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CONV_UUID = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
|
||||
|
||||
class TestReadLocalToolResult:
|
||||
"""Tests for _read_local_tool_result (local disk fallback)."""
|
||||
|
||||
def _make_tool_result(self, encoded: str, filename: str, content: bytes) -> str:
|
||||
"""Create a tool-results file and return its path."""
|
||||
tool_dir = os.path.join(SDK_PROJECTS_DIR, encoded, _CONV_UUID, "tool-results")
|
||||
os.makedirs(tool_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_dir, filename)
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(content)
|
||||
return filepath
|
||||
|
||||
def _cleanup(self, encoded: str) -> None:
|
||||
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
|
||||
|
||||
def test_read_text_file(self):
|
||||
"""Read a UTF-8 text tool-result file."""
|
||||
encoded = "-tmp-copilot-local-read-text"
|
||||
path = self._make_tool_result(encoded, "output.txt", b"hello world")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 0, None, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
decoded = base64.b64decode(result.content_base64).decode("utf-8")
|
||||
assert decoded == "hello world"
|
||||
assert "text/plain" in result.mime_type
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_read_text_with_offset(self):
|
||||
"""Read a slice of a text file using char_offset and char_length."""
|
||||
encoded = "-tmp-copilot-local-read-offset"
|
||||
path = self._make_tool_result(encoded, "data.txt", b"ABCDEFGHIJ")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 3, 4, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
decoded = base64.b64decode(result.content_base64).decode("utf-8")
|
||||
assert decoded == "DEFG"
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_read_binary_file(self):
|
||||
"""Binary files are returned as raw base64."""
|
||||
encoded = "-tmp-copilot-local-read-binary"
|
||||
binary_data = bytes(range(256))
|
||||
path = self._make_tool_result(encoded, "image.png", binary_data)
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 0, None, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
decoded = base64.b64decode(result.content_base64)
|
||||
assert decoded == binary_data
|
||||
assert "binary" in result.message
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_disallowed_path_rejected(self):
|
||||
"""Paths not under allowed directories are rejected."""
|
||||
result = _read_local_tool_result("/etc/passwd", 0, None, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "not allowed" in result.message.lower()
|
||||
|
||||
def test_file_not_found(self):
|
||||
"""Missing files return an error."""
|
||||
encoded = "-tmp-copilot-local-read-missing"
|
||||
tool_dir = os.path.join(SDK_PROJECTS_DIR, encoded, _CONV_UUID, "tool-results")
|
||||
os.makedirs(tool_dir, exist_ok=True)
|
||||
path = os.path.join(tool_dir, "nope.txt")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 0, None, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "not found" in result.message.lower()
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_file_too_large(self, monkeypatch):
|
||||
"""Files exceeding the size limit are rejected."""
|
||||
encoded = "-tmp-copilot-local-read-large"
|
||||
# Create a small file but fake os.path.getsize to return a huge value
|
||||
path = self._make_tool_result(encoded, "big.txt", b"small")
|
||||
token = _current_project_dir.set(encoded)
|
||||
monkeypatch.setattr(
|
||||
"os.path.getsize", lambda _: _MAX_LOCAL_TOOL_RESULT_BYTES + 1
|
||||
)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 0, None, "s1")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert "too large" in result.message.lower()
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_offset_beyond_file_length(self):
|
||||
"""Offset past end-of-file returns empty content."""
|
||||
encoded = "-tmp-copilot-local-read-past-eof"
|
||||
path = self._make_tool_result(encoded, "short.txt", b"abc")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 999, 10, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
decoded = base64.b64decode(result.content_base64).decode("utf-8")
|
||||
assert decoded == ""
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_zero_length_read(self):
|
||||
"""Requesting zero characters returns empty content."""
|
||||
encoded = "-tmp-copilot-local-read-zero-len"
|
||||
path = self._make_tool_result(encoded, "data.txt", b"ABCDEF")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 2, 0, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
decoded = base64.b64decode(result.content_base64).decode("utf-8")
|
||||
assert decoded == ""
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_mime_type_from_json_extension(self):
|
||||
"""JSON files get application/json MIME type, not hardcoded text/plain."""
|
||||
encoded = "-tmp-copilot-local-read-json"
|
||||
path = self._make_tool_result(encoded, "result.json", b'{"key": "value"}')
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 0, None, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
assert result.mime_type == "application/json"
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_mime_type_from_png_extension(self):
|
||||
"""Binary .png files get image/png MIME type via mimetypes."""
|
||||
encoded = "-tmp-copilot-local-read-png-mime"
|
||||
binary_data = bytes(range(256))
|
||||
path = self._make_tool_result(encoded, "chart.png", binary_data)
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 0, None, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
assert result.mime_type == "image/png"
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_explicit_sdk_cwd_parameter(self):
|
||||
"""The sdk_cwd parameter overrides get_sdk_cwd() for path validation."""
|
||||
encoded = "-tmp-copilot-local-read-sdkcwd"
|
||||
path = self._make_tool_result(encoded, "out.txt", b"content")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
# Pass sdk_cwd explicitly — should still succeed because the path
|
||||
# is under SDK_PROJECTS_DIR which is always allowed.
|
||||
result = _read_local_tool_result(
|
||||
path, 0, None, "s1", sdk_cwd="/tmp/copilot-test"
|
||||
)
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
decoded = base64.b64decode(result.content_base64).decode("utf-8")
|
||||
assert decoded == "content"
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
def test_offset_with_no_length_reads_to_end(self):
|
||||
"""When char_length is None, read from offset to end of file."""
|
||||
encoded = "-tmp-copilot-local-read-offset-noLen"
|
||||
path = self._make_tool_result(encoded, "data.txt", b"0123456789")
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
result = _read_local_tool_result(path, 5, None, "s1")
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
decoded = base64.b64decode(result.content_base64).decode("utf-8")
|
||||
assert decoded == "56789"
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
self._cleanup(encoded)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ReadWorkspaceFileTool fallback to _read_local_tool_result
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_read_workspace_file_falls_back_to_local_tool_result(setup_test_data):
|
||||
"""When _resolve_file returns ErrorResponse for an allowed local path,
|
||||
ReadWorkspaceFileTool should fall back to _read_local_tool_result."""
|
||||
user = setup_test_data["user"]
|
||||
session = make_session(user.id)
|
||||
|
||||
# Create a real tool-result file on disk so the fallback can read it.
|
||||
encoded = "-tmp-copilot-fallback-test"
|
||||
conv_uuid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
tool_dir = os.path.join(SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-results")
|
||||
os.makedirs(tool_dir, exist_ok=True)
|
||||
filepath = os.path.join(tool_dir, "result.txt")
|
||||
with open(filepath, "w") as f:
|
||||
f.write("fallback content")
|
||||
|
||||
token = _current_project_dir.set(encoded)
|
||||
try:
|
||||
# Mock _resolve_file to return an ErrorResponse (simulating "file not
|
||||
# found in workspace") so the fallback branch is exercised.
|
||||
mock_resolve = AsyncMock(
|
||||
return_value=ErrorResponse(
|
||||
message="File not found at path: result.txt",
|
||||
session_id=session.session_id,
|
||||
)
|
||||
)
|
||||
with patch("backend.copilot.tools.workspace_files._resolve_file", mock_resolve):
|
||||
read_tool = ReadWorkspaceFileTool()
|
||||
result = await read_tool._execute(
|
||||
user_id=user.id,
|
||||
session=session,
|
||||
path=filepath,
|
||||
)
|
||||
|
||||
# Should have fallen back to _read_local_tool_result and succeeded.
|
||||
assert isinstance(result, WorkspaceFileContentResponse), (
|
||||
f"Expected fallback to local read, got {type(result).__name__}: "
|
||||
f"{getattr(result, 'message', '')}"
|
||||
)
|
||||
decoded = base64.b64decode(result.content_base64).decode("utf-8")
|
||||
assert decoded == "fallback content"
|
||||
mock_resolve.assert_awaited_once()
|
||||
finally:
|
||||
_current_project_dir.reset(token)
|
||||
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_read_workspace_file_no_fallback_when_resolve_succeeds(setup_test_data):
|
||||
"""When _resolve_file succeeds, the local-disk fallback must NOT be invoked."""
|
||||
user = setup_test_data["user"]
|
||||
session = make_session(user.id)
|
||||
|
||||
fake_file_id = "fake-file-id-001"
|
||||
fake_content = b"workspace content"
|
||||
|
||||
# Build a minimal file_info stub that the tool's happy-path needs.
|
||||
class _FakeFileInfo:
|
||||
id = fake_file_id
|
||||
name = "result.json"
|
||||
path = "/result.json"
|
||||
mime_type = "text/plain"
|
||||
size_bytes = len(fake_content)
|
||||
|
||||
mock_resolve = AsyncMock(return_value=(fake_file_id, _FakeFileInfo()))
|
||||
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.read_file_by_id = AsyncMock(return_value=fake_content)
|
||||
|
||||
with (
|
||||
patch("backend.copilot.tools.workspace_files._resolve_file", mock_resolve),
|
||||
patch(
|
||||
"backend.copilot.tools.workspace_files.get_workspace_manager",
|
||||
AsyncMock(return_value=mock_manager),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.workspace_files._read_local_tool_result"
|
||||
) as patched_local,
|
||||
):
|
||||
read_tool = ReadWorkspaceFileTool()
|
||||
result = await read_tool._execute(
|
||||
user_id=user.id,
|
||||
session=session,
|
||||
file_id=fake_file_id,
|
||||
)
|
||||
|
||||
# Fallback must not have been called.
|
||||
patched_local.assert_not_called()
|
||||
# Normal workspace path must have produced a content response.
|
||||
assert isinstance(result, WorkspaceFileContentResponse)
|
||||
assert base64.b64decode(result.content_base64) == fake_content
|
||||
|
||||
@@ -70,6 +70,10 @@ def _msg_tokens(msg: dict, enc) -> int:
|
||||
# Count tool result tokens
|
||||
tool_call_tokens += _tok_len(item.get("tool_use_id", ""), enc)
|
||||
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
||||
elif isinstance(item, dict) and item.get("type") == "text":
|
||||
# Count text block tokens (standard: "text" key, fallback: "content")
|
||||
text_val = item.get("text") or item.get("content", "")
|
||||
tool_call_tokens += _tok_len(text_val, enc)
|
||||
elif isinstance(item, dict) and "content" in item:
|
||||
# Other content types with content field
|
||||
tool_call_tokens += _tok_len(item.get("content", ""), enc)
|
||||
@@ -145,10 +149,16 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
|
||||
if len(ids) <= max_tok:
|
||||
return text # nothing to do
|
||||
|
||||
# Need at least 3 tokens (head + ellipsis + tail) for meaningful truncation
|
||||
if max_tok < 1:
|
||||
return ""
|
||||
mid = enc.encode(" … ")
|
||||
if max_tok < 3:
|
||||
return enc.decode(ids[:max_tok])
|
||||
|
||||
# Split the allowance between the two ends:
|
||||
head = max_tok // 2 - 1 # -1 for the ellipsis
|
||||
tail = max_tok - head - 1
|
||||
mid = enc.encode(" … ")
|
||||
return enc.decode(ids[:head] + mid + ids[-tail:])
|
||||
|
||||
|
||||
@@ -545,6 +555,14 @@ async def _summarize_messages_llm(
|
||||
"- Actions taken and key decisions made\n"
|
||||
"- Technical specifics (file names, tool outputs, function signatures)\n"
|
||||
"- Errors encountered and resolutions applied\n\n"
|
||||
"IMPORTANT: Preserve all concrete references verbatim — these are small but "
|
||||
"critical for continuing the conversation:\n"
|
||||
"- File paths and directory paths (e.g. /src/app/page.tsx, ./output/result.csv)\n"
|
||||
"- Image/media file paths from tool outputs\n"
|
||||
"- URLs, API endpoints, and webhook addresses\n"
|
||||
"- Resource IDs, session IDs, and identifiers\n"
|
||||
"- Tool names that were called and their key parameters\n"
|
||||
"- Environment variables, config keys, and credentials names (not values)\n\n"
|
||||
"Include ONLY the sections below that have relevant content "
|
||||
"(skip sections with nothing to report):\n\n"
|
||||
"## 1. Primary Request and Intent\n"
|
||||
@@ -552,7 +570,8 @@ async def _summarize_messages_llm(
|
||||
"## 2. Key Technical Concepts\n"
|
||||
"Technologies, frameworks, tools, and patterns being used or discussed.\n\n"
|
||||
"## 3. Files and Resources Involved\n"
|
||||
"Specific files examined or modified, with relevant snippets and identifiers.\n\n"
|
||||
"Specific files examined or modified, with relevant snippets and identifiers. "
|
||||
"Include exact file paths, image paths from tool outputs, and resource URLs.\n\n"
|
||||
"## 4. Errors and Fixes\n"
|
||||
"Problems encountered, error messages, and their resolutions.\n\n"
|
||||
"## 5. All User Messages\n"
|
||||
@@ -566,7 +585,7 @@ async def _summarize_messages_llm(
|
||||
},
|
||||
{"role": "user", "content": f"Summarize:\n\n{conversation_text}"},
|
||||
],
|
||||
max_tokens=1500,
|
||||
max_tokens=2000,
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
@@ -686,11 +705,15 @@ async def compress_context(
|
||||
msgs = [summary_msg] + recent_msgs
|
||||
|
||||
logger.info(
|
||||
f"Context summarized: {original_count} -> {total_tokens()} tokens, "
|
||||
f"summarized {messages_summarized} messages"
|
||||
"Context summarized: %d -> %d tokens, summarized %d messages",
|
||||
original_count,
|
||||
total_tokens(),
|
||||
messages_summarized,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Summarization failed, continuing with truncation: {e}")
|
||||
logger.warning(
|
||||
"Summarization failed, continuing with truncation: %s", e
|
||||
)
|
||||
# Fall through to content truncation
|
||||
|
||||
# ---- STEP 2: Normalize content ----------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user