Compare commits

..

59 Commits

Author SHA1 Message Date
Zamil Majdy
3c45687e10 Merge remote-tracking branch 'origin/dev' into fix/copilot-transcript-compaction 2026-03-17 06:16:56 +07:00
Zamil Majdy
869743ff0e fix(copilot): replace simulated retry logic in test scenarios 4-7 with _reduce_context calls
Scenarios 4 and 6 (and parts of 5, 7) were testing Python if/for logic that
mimicked the retry tree instead of calling production code. Replaced with
direct calls to _reduce_context which exercises the actual decision path.
2026-03-16 06:28:15 +07:00
Zamil Majdy
46c35cfca6 fix(copilot): add comment explaining why seen_parents is local per-entry in strip_progress_entries
Clarifies the cycle-detection guard: seen_parents is intentionally local to
each entry's ancestry walk rather than shared across iterations, to avoid
incorrectly short-circuiting valid UUID reuse in different subtrees.
2026-03-16 06:23:59 +07:00
Zamil Majdy
8748b3e49d fix(copilot): guard against base64 data in _flatten_tool_result_content
Inner sub-blocks of type 'image' or 'document' inside tool_result content
may carry base64-encoded binary data. Previously these were serialized via
json.dumps which would include the full payload in the compaction input.
Now they emit a placeholder like [__image__] instead.
2026-03-16 06:23:02 +07:00
Zamil Majdy
5f4e5eb207 fix(copilot): remove unused transcript_factory pytest fixture from conftest
The fixture only wrapped build_test_transcript which tests import directly.
2026-03-16 06:21:59 +07:00
Zamil Majdy
2479de7ac9 fix(copilot): use AsyncIterator for _run_stream_attempt and stream_chat_completion_sdk
AsyncIterator[T] is the conventional annotation for async generator
functions whose callers use async-for only. AsyncGenerator[T, S] is
reserved for generators that need aclose() (kept for _iter_sdk_messages
whose tests call gen.aclose() directly).
2026-03-16 06:20:57 +07:00
Zamil Majdy
f4dee98508 fix(copilot): wrap write_transcript_to_tempfile with asyncio.to_thread
The function performs synchronous file I/O (open/write) which blocks the
event loop when called from async context. Both call sites in
_reduce_context and stream_chat_completion_sdk now use asyncio.to_thread.
2026-03-16 06:18:20 +07:00
Zamil Majdy
bd23caa116 fix(copilot): replace assert with RuntimeError guard for transport init check
assert is stripped in Python optimized mode (-O); replace with explicit
RuntimeError so transport initialization failures are always caught.
2026-03-16 06:16:23 +07:00
Zamil Majdy
17bbd18521 refactor(copilot): extract _run_stream_attempt + add integration tests
- Extract nested `_run_stream_attempt` to module level with `_StreamContext`
  dataclass carrying the read-only per-request variables (session, ids,
  attachments, compaction, lock) — eliminates the 330-line nested generator
  and makes the function independently testable
- Add 2 real integration tests for `stream_chat_completion_sdk` retry loop
  that patch `ClaudeSDKClient` directly and verify the actual generator
  retries on prompt-too-long (not a simulation)
- Add 3 direct tests for `_run_compression` covering: no-client → truncation,
  LLM success, and LLM fail → truncation fallback
2026-03-15 22:31:04 +07:00
Zamil Majdy
de73d89e39 fix(copilot): encode StreamStatus as SSE comment to avoid AI SDK parse error
The frontend AI SDK validates every `data:` line against a strict Zod
union of known chunk types. `"status"` is not in that union, so sending
it as `data:` causes a schema-validation throw that breaks the stream.
Override `to_sse()` to emit an SSE comment (`: status ...`) instead,
which is silently discarded by EventSource parsers while still keeping
the connection alive.
2026-03-15 14:06:44 +07:00
Zamil Majdy
29efcfb280 fix(copilot): user-friendly exhaustion error, reduce E2B timeout, add events_yielded guard test
- Replace raw SDK error text with actionable message when all retry
  attempts are exhausted ("Your conversation is too long...")
- Reduce e2b_sandbox_timeout from 600s (10 min) to 420s (7 min)
- Add TestEventsYieldedGuard tests verifying the retry loop breaks
  immediately when events have already been sent to the frontend
2026-03-15 13:50:30 +07:00
Zamil Majdy
f1151c5cc1 fix(copilot): address autogpt-reviewer conditions — rename, docstrings, cross-refs
- Rename `transcript_caused_error` to `skip_transcript_upload` to clarify
  its single responsibility (upload guard in finally block)
- Drop leading underscores from actively-read variables (`_attempt`,
  `_tried_compaction`, `_events_yielded`, `_pre_attempt_msg_count`)
- Add docstrings for `_next_msg` and `_on_stderr` (raises coverage >80%)
- Add cross-reference docstrings between the three compaction paths
  (`_compress_messages`, `compact_transcript`, `CompactionTracker`)
- Update retry_scenarios_test.py to match renamed variable
2026-03-15 13:34:24 +07:00
Zamil Majdy
11dbc08450 fix(copilot): add retry user feedback, move constant to module level, improve docstrings
- Add StreamStatus response type and emit "Optimizing conversation
  context..." before compaction retries so users get feedback during
  the 10-30s wait
- _MAX_STREAM_ATTEMPTS already at module level alongside
  _PROMPT_TOO_LONG_PATTERNS (reviewer item addressed)
- Expand docstrings for _messages_to_transcript and
  _transcript_to_messages in transcript.py
2026-03-15 04:11:07 +07:00
Zamil Majdy
bca314cfbe fix(copilot): exclude heartbeat events from _events_yielded retry guard
StreamHeartbeat events are non-substantive keep-alive signals that
clients ignore. Counting them in _events_yielded incorrectly prevented
the retry-with-compaction logic from running when a tool call took
longer than 10 seconds before a prompt-too-long error.
2026-03-15 00:17:08 +07:00
Zamil Majdy
c4a51d2804 fix(copilot): skip transcript upload when builder does not cover full session prefix
When transcript download fails or returns invalid content on a session with
prior messages, the transcript builder only contains the current turn. Uploading
it with message_count=len(session.messages) would overstate coverage, causing
_build_query_message to skip gap reconstruction on the next turn.

Add _transcript_covers_prefix flag that is set to False when download fails,
and check it before uploading to prevent partial transcripts from being stored
as if they covered the entire session history.
2026-03-14 23:48:54 +07:00
Zamil Majdy
e17e1616d9 fix(copilot): align service_helpers_test with prompt-too-long pattern removal
The max_tokens_exceeded pattern was intentionally removed from
_PROMPT_TOO_LONG_PATTERNS (too broad, matches non-context errors).
prompt_too_long_test.py correctly asserts it does NOT match, but
service_helpers_test.py still expected it to match. Fix the test
to align with the intended behavior.
2026-03-14 23:47:55 +07:00
Zamil Majdy
ca0b3cde16 fix(copilot): add max_tokens_exceeded to prompt-too-long patterns
The pattern was missing from _PROMPT_TOO_LONG_PATTERNS, causing
test_max_tokens_exceeded and test_all_patterns to fail. Also sync
test_all_patterns with the full pattern list (add input tokens exceed).
2026-03-14 23:43:07 +07:00
Zamil Majdy
045096d863 fix(copilot): flush unresolved tool calls on stream error to prevent stale frontend UI
When a mid-stream exception interrupts _run_stream_attempt, the post-stream
cleanup (tool call flushing, text-end closing) was bypassed. This left the
frontend with stale UI elements (e.g. active spinners) for tools that started
but never received completion events.

Add error-path cleanup after the retry loop: call _end_text_if_open and
_flush_unresolved_tool_calls on the adapter before yielding the StreamError,
so the frontend receives proper closure events for any in-flight tool calls.
2026-03-14 22:50:31 +07:00
Zamil Majdy
fc844fde1f fix: add truncation timeout, closure contract, shared test fixtures
- Document outer-scope variable contract in _run_stream_attempt closure
- Add 30s timeout to truncation fallback in _run_compression
- Fix _truncate_middle_tokens edge case for max_tok < 3
- Pass model parameter to compact_transcript to avoid ChatConfig() per call
- Extract _build_transcript to shared conftest.py fixture
- Fix scenario numbering gap in retry_scenarios_test.py
- Add comment explaining E2B sandbox timeout increase
2026-03-14 22:28:41 +07:00
Zamil Majdy
9642332332 Merge remote-tracking branch 'origin/dev' into fix/copilot-transcript-compaction 2026-03-14 22:11:19 +07:00
Zamil Majdy
47d91e915f fix(backend/copilot): fix token counting, truncation edge case, and remaining logging issues
- Fix _msg_tokens not counting Anthropic text blocks (underestimated tokens)
- Fix _truncate_middle_tokens producing wrong output when max_tok < 3
- Add cycle guard to strip_progress_entries reparenting loop
- Replace bare except with logged exception in transcript metadata loading
- Convert remaining f-string logger calls to %-style
2026-03-14 22:06:48 +07:00
Zamil Majdy
df75e130da fix(copilot): clear compact event on retry, use distinct error codes
- Clear _compact_start in CompactionTracker.reset_for_query() to prevent
  spurious compaction UI events on retry after stream error
- Use "sdk_stream_error" code for non-context errors that break the loop
  immediately, reserve "all_attempts_exhausted" for when all retries are
  consumed
2026-03-14 21:56:49 +07:00
Zamil Majdy
d0fc7ed3b2 fix(copilot): code quality improvements across SDK modules
- Use lazy % logging instead of f-strings in service.py
- Move inner imports to top-level in transcript.py
- Fix redundant except (FileNotFoundError, Exception) → except Exception
- Convert entry_types loop to list comprehension
- Consolidate duplicate tool_calls iteration in compaction.py
- Remove redundant `or None` on isCompactSummary in transcript_builder.py
- Fix test mock paths after top-level import refactor
2026-03-14 21:52:39 +07:00
Zamil Majdy
9781aa93e3 fix(copilot): improve summarization prompt to retain file paths and references
- Add explicit instructions to preserve file paths, image paths, URLs,
  resource IDs, tool names, and config keys during LLM summarization
- Bump max_tokens from 1500 to 2000 for more detailed summaries
- Document compaction fidelity equivalence between Plan B and Plan C
- Add docstrings explaining tool_use structure flattening trade-off
- Fix remaining f-string logging in transcript.py to %s-style
2026-03-14 21:41:09 +07:00
Zamil Majdy
f043fa7b6a fix(copilot): increase E2B timeout to 15min, fix internal e2b import
- Increase e2b_sandbox_timeout from 300s (5min) to 900s (15min) to
  avoid interrupting long-running sandbox commands like pip install
  with compilation.
- Import SandboxLifecycle from public e2b API instead of internal
  e2b.sandbox.sandbox_api module path.
2026-03-14 21:33:11 +07:00
Zamil Majdy
ca4dad979d fix(copilot): narrow max_tokens pattern, NamedTuple for _reduce_context, add 23 unit tests
- Change "max_tokens" pattern to "max_tokens_exceeded" to avoid false
  matches on output param errors (e.g. "max_tokens must be at least 1").
- Replace 5-element tuple return from _reduce_context with ReducedContext
  NamedTuple for clarity at call sites.
- Convert remaining f-string logging to lazy % format in security_hooks
  and sanitize untrusted error field in post_tool_failure_hook.
- Add 23 unit tests for _is_prompt_too_long (exception chains, cycle
  detection, all patterns), _reduce_context (6 scenarios), and
  _iter_sdk_messages (heartbeat, propagation, cleanup).
2026-03-14 21:32:32 +07:00
Zamil Majdy
4559d13b29 fix(copilot): classify errors for retry, add compaction timeout, harden logging
- Re-introduce _is_prompt_too_long() to only retry context-size errors;
  non-context errors (network, auth, rate-limit) surface immediately
  instead of triggering unnecessary compaction.
- Add 60-second timeout on LLM compaction to prevent hung calls from
  blocking the retry path indefinitely.
- Convert f-string logging to lazy % format across security_hooks.py
  and transcript.py.
- Sanitize `trigger` field in NotificationHook to prevent log injection.
- Add docstrings explaining structural loss in _flatten_* helpers.
2026-03-14 21:22:30 +07:00
Zamil Majdy
4cc1baac54 fix(copilot): restore extracted-code comments, fix parentUuid root entry
Restore detailed inline comments that explain *why* specific patterns
are used (non-cancelling heartbeats, PostToolUse race-condition fix,
assistant entry ordering). These were shortened during extraction.

Also fix parentUuid on the root transcript entry: use empty string ""
(matching CLI convention) instead of null.
2026-03-14 21:10:01 +07:00
Zamil Majdy
9d1881d909 fix(copilot): fix stream_err UnboundLocalError, move local import to top-level
Python 3 deletes `except ... as e` bindings after the except block
exits (PEP 3110). Save the exception to a persistent variable so
the for...else exhaustion block can reference it.

Also move `from backend.copilot.config import ChatConfig` from a
local import to top-level in transcript.py.
2026-03-14 21:06:06 +07:00
Zamil Majdy
384b261e7f refactor(copilot): extract _run_stream_attempt to flatten retry loop
Move the entire streaming attempt (client lifecycle, query, message
iteration, response dispatch, post-stream cleanup) into an inner async
generator. The retry loop now becomes a clean for/else with try/except,
reducing nesting by 2-3 levels.
2026-03-14 20:35:42 +07:00
Zamil Majdy
4cc0bbf472 refactor(copilot): extract _iter_sdk_messages to flatten streaming loop
Move the heartbeat-based SDK message iteration (pending task lifecycle,
timeout, cleanup) into a standalone async generator. This removes two
levels of nesting (try/while) from the main streaming loop and moves
post-stream processing outside the `async with` block.
2026-03-14 20:17:42 +07:00
Zamil Majdy
3082f878fe refactor(copilot): extract _reduce_context helper, rename vars for clarity
Consolidate retry-setup logic into _reduce_context() to reduce
indentation in the streaming loop. Rename _query_attempt → _attempt
and _MAX_QUERY_ATTEMPTS → _MAX_STREAM_ATTEMPTS for clarity.
2026-03-14 20:00:59 +07:00
Zamil Majdy
33cd967e66 refactor(backend/copilot): rename retry variables for clarity
- _MAX_QUERY_ATTEMPTS → _MAX_STREAM_ATTEMPTS with docstring explaining
  the 3-step progression (original → compacted → no transcript)
- _query_attempt → _attempt
- _compaction_attempted → _tried_compaction
- Log message: "Retry attempt" → "Retrying with reduced context"
2026-03-14 19:56:00 +07:00
Zamil Majdy
b599858dea simplify(backend/copilot): remove error classification, always retry
Remove RetryStrategy enum, _classify_error(), and _TRANSCRIPT_ERROR_PATTERNS.
All streaming errors now follow the same retry path:
  1. Original query (with transcript)
  2. Compacted transcript (LLM summarization)
  3. No transcript (DB-message rebuild)

No need to pattern-match error messages — compaction is cheap to attempt
and harmless if it fails (falls through to DB rebuild anyway).
2026-03-14 19:39:33 +07:00
Zamil Majdy
629ecc9436 fix(backend/copilot): fix retry loop trying compaction twice, rename helpers
- Track _compaction_attempted to prevent trying compaction again after
  it already succeeded but the compacted transcript also failed
- Rename _apply_plan_b/_apply_plan_c to descriptive names:
  _retry_with_compacted_transcript / _retry_without_transcript
2026-03-14 19:35:45 +07:00
Zamil Majdy
4b92fd09c9 refactor(backend/copilot): retry on any streaming error, not just prompt-too-long
Replace the prompt-too-long-only retry with a general error retry strategy:
- RetryStrategy.COMPACT_THEN_FALLBACK: try Plan B (compact transcript),
  then Plan C (DB fallback) — used for all errors by default
- RetryStrategy.FALLBACK_ONLY: skip compaction, go straight to Plan C —
  used for transcript/JSON parse errors where compaction can't help

Extract _apply_plan_b() and _apply_plan_c() helpers to deduplicate the
inline retry setup logic. Replace _is_prompt_too_long() with
_classify_error() that determines the retry strategy.
2026-03-14 18:41:33 +07:00
Zamil Majdy
41872e003b fix(backend/copilot): use shared get_openai_client() for transcript compaction
Replace per-call AsyncOpenAI instantiation with the process-cached
get_openai_client() from backend.util.clients, consistent with
optimize_blocks.py usage.
2026-03-14 11:08:07 +07:00
Zamil Majdy
5dc8d6c848 fix(platform): harden retry logic — walk exception chains, handle write failures, add 20 edge-case tests
- _is_prompt_too_long walks __cause__/__context__ to detect wrapped errors
- write_transcript_to_tempfile failure now sets transcript_caused_error
- _run_compression catches all exceptions for truncation fallback
- _flatten_tool_result_content handles None content
- Remove circular import workaround (ChatConfig as top-level import)
- Add 20 new tests: exception chains, malformed transcripts, unicode,
  tool use/result pairs, session rollback, write failure propagation
2026-03-14 10:34:21 +07:00
Zamil Majdy
8c8e596302 fix(backend/copilot): rollback partial session messages on prompt-too-long retry
Save message count before each retry attempt and restore on prompt-too-long
to prevent duplicate assistant messages from failed attempts.
2026-03-14 10:24:16 +07:00
Zamil Majdy
ad6e2f0ca1 fix(backend/copilot): address PR review - top-level imports, non-text block placeholders, no-op compaction
- Move openai and compress_context imports to top-level in transcript.py
- Add [type] placeholders in _flatten_assistant_content and
  _flatten_tool_result_content for non-text blocks (e.g. images)
- Return None from compact_transcript when was_compacted is False
- Enhance _run_compression docstring with fallback behavior
2026-03-14 10:11:25 +07:00
Zamil Majdy
d1ef92a79a fix(platform): add retry scenario tests, add request-too-large pattern, fix compact_transcript to return None when not compacted 2026-03-14 10:10:32 +07:00
Zamil Majdy
15d36233b6 fix(platform): fix misleading log message, reset stream_completed on retry 2026-03-14 06:20:09 +07:00
Zamil Majdy
618dde9d02 fix(platform): address review - skip upload on DB fallback, simplify retry break, document local imports 2026-03-14 06:14:36 +07:00
Zamil Majdy
39c0fece87 Merge origin/dev into fix/copilot-transcript-compaction
Resolve merge conflicts in service.py, transcript.py, and transcript_test.py.
Incorporates dev's compaction tracking (emit_end_if_ready, entries_replaced,
read_compacted_entries) into the retry loop structure.
2026-03-14 06:08:07 +07:00
Zamil Majdy
41591fd76f fix(backend/copilot): add try-compact-retry for prompt-too-long errors
When the transcript exceeds the context window, the SDK now retries
with escalating strategies:
1. Use transcript as-is (normal path)
2. Compact transcript via LLM summarization, retry with --resume
3. Drop transcript entirely, fall back to DB-reconstructed context

- Add _is_prompt_too_long detector with specific patterns
- Add compact_transcript + helpers (_transcript_to_messages,
  _messages_to_transcript, _flatten_*) to transcript.py
- Remove dead delete_transcript function and its tests
- Guard finally-block transcript upload with transcript_caused_error
  to prevent re-uploading a problematic transcript
- Add 90 tests covering detection, helpers, and compaction
2026-03-14 05:54:39 +07:00
Zamil Majdy
7d95321fd9 test(backend/copilot): add e2e compaction lifecycle tests
Exercises the full service.py compaction flow end-to-end:
TranscriptBuilder load → CompactionTracker state machine →
read_compacted_entries → replace_entries → export → roundtrip.
2026-03-14 01:43:13 +07:00
Zamil Majdy
4ebc759f0a fix(backend/copilot): prevent duplicate transcript entries after mid-stream compaction
When compaction ends, replace_entries loads the CLI session file which
already contains the current message. Skip append_assistant and
append_tool_result for that iteration to avoid duplicates that could
cause 'prompt is too long' errors on subsequent turns.
2026-03-14 01:17:07 +07:00
Zamil Majdy
3e509847fd fix(platform): preserve isCompactSummary in TranscriptEntry roundtrip
Add isCompactSummary field to TranscriptEntry model so compaction
summaries survive the export→load_previous roundtrip. Without this,
exported summaries with type="summary" were stripped on reload since
"summary" is in STRIPPABLE_TYPES.

Also add integration tests simulating the full compaction flow
(load → append → compact → replace → export → reload).
2026-03-14 00:38:35 +07:00
Zamil Majdy
1023134458 fix(platform): address PR review - CompactionResult, build-then-swap, last-summary, dedup helpers
- Replace side-effect @property with CompactionResult dataclass for TOCTOU safety
- Use build-then-swap in replace_entries to prevent data loss on corrupt input
- Fix read_compacted_entries to use LAST isCompactSummary (not first)
- Guard isCompactSummary in strip_progress_entries and TranscriptBuilder
- Extract _projects_base(), _build_path_from_parts(), _build_meta_storage_path()
- Extract TranscriptBuilder._parse_entry() static method
- Sanitize transcript_path in pre_compact_hook before logging
- Update compaction_test.py and transcript_test.py for new behaviors
2026-03-14 00:02:50 +07:00
Zamil Majdy
8f0f6ced10 fix(platform): address PR #12401 review - one-shot compaction flag, path validation, tests
- Make compaction_just_ended a one-shot flag that resets after being read,
  fixing multiple mid-stream compactions not triggering replace_entries
- Add path validation to read_compacted_entries() to reject paths outside
  the CLI projects directory
- Reset _transcript_path in reset_for_query() for consistency
- Add tests for read_compacted_entries and TranscriptBuilder.replace_entries
- Preserve isCompactSummary entries during STRIPPABLE_TYPES filtering
2026-03-13 23:46:46 +07:00
Zamil Majdy
9f60fda37f fix(platform): sync TranscriptBuilder with CLI on mid-stream compaction
TranscriptBuilder accumulated ALL messages (pre + post compaction),
making it larger than the CLI's active context after mid-stream
compaction. This caused "Prompt is too long" errors on resume.

Now when the PreCompact hook fires, CompactionTracker captures the
transcript_path. When compaction ends (next message arrives), we read
the CLI session file, find the isCompactSummary entry, and replace
TranscriptBuilder's entries with the compacted entries. This ensures
TranscriptBuilder always mirrors the CLI's active context.

Removes the racy read_cli_session_file fallback from the finally block
- TranscriptBuilder is now the single source of truth for uploads.
2026-03-13 23:26:06 +07:00
Zamil Majdy
b04f806760 fix(backend/copilot): fix mock patch target in delete_transcript tests 2026-03-13 19:50:02 +07:00
Zamil Majdy
0246623337 fix(backend/copilot): address review comments on transcript handling
- Move `import shutil` to module top-level (HIGH)
- Extract `_safe_glob_jsonl` helper to shorten `read_cli_session_file` (MEDIUM)
- Inline single-use `err_str` variable (LOW)
- Use lazy `%s` formatting in all logger calls (LOW)
- Fix implicit string concatenation in log format string (NIT)
- Compute accurate entry_count when using CLI transcript path (WARNING)
2026-03-13 19:41:09 +07:00
Zamil Majdy
696f533e2e Merge remote-tracking branch 'origin/dev' into fix/copilot-transcript-compaction 2026-03-13 19:27:08 +07:00
Zamil Majdy
8c7b077753 fix(backend/copilot): handle filesystem exceptions in read_cli_session_file
Wrap p.resolve() and stat() calls in try/except to prevent unhandled
OSError/RuntimeError from propagating and silently dropping transcript
uploads when the caller falls back to TranscriptBuilder.
2026-03-13 18:29:12 +07:00
Zamil Majdy
a1f34316c7 fix(backend/copilot): replace glob with pathlib and add path traversal protection
- Replace `import glob` with `pathlib.Path.glob()` in `read_cli_session_file`
- Add symlink path traversal validation on glob results using `is_relative_to`
- Add unit tests for `read_cli_session_file` and `_cli_project_dir`
2026-03-13 18:19:13 +07:00
Zamil Majdy
152f54f33d fix(backend/copilot): set skip_transcript_upload before await delete
Move flag assignment before the await to prevent re-upload if
cancellation interrupts delete_transcript.
2026-03-13 17:16:09 +07:00
Zamil Majdy
6baeb117f7 fix(backend/copilot): use separate flag for transcript upload suppression
Use skip_transcript_upload instead of reusing use_resume to gate
transcript uploads. use_resume starts False and only becomes True
after a successful download, so gating on it prevented first-turn
transcripts from ever being uploaded (bootstrap paradox).
2026-03-13 17:06:21 +07:00
Zamil Majdy
2adeb63ebc fix(backend/copilot): use CLI session file for transcript upload to preserve compaction
The TranscriptBuilder accumulates all raw SDK stream messages including
pre-compaction content. When the CLI compacts mid-stream, the uploaded
transcript still contains the full uncompacted messages, causing
"Prompt is too long" errors on the next --resume turn.

Fix:
- Read the CLI's own session file (~/.claude/projects/<cwd>/*.jsonl)
  which reflects mid-stream compaction, instead of TranscriptBuilder
- Extract _cli_project_dir() helper, refactor cleanup_cli_project_dir
- On "Prompt is too long" error with --resume, delete the oversized
  transcript so the next turn falls back to compression-based context
2026-03-13 16:52:51 +07:00
25 changed files with 3599 additions and 1542 deletions

View File

@@ -115,7 +115,7 @@ class ChatConfig(BaseSettings):
description="E2B sandbox template to use for copilot sessions.",
)
e2b_sandbox_timeout: int = Field(
default=300, # 5 min safety net — explicit per-turn pause is the primary mechanism
default=420, # 7 min safety net — allows headroom for compaction retries
description="E2B sandbox running-time timeout (seconds). "
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
"mechanism; this is the safety net.",

View File

@@ -17,17 +17,8 @@ from backend.util.workspace import WorkspaceManager
if TYPE_CHECKING:
from e2b import AsyncSandbox
# Allowed base directory for the Read tool. Public so service.py can use it
# for sweep operations without depending on a private implementation detail.
# Respects CLAUDE_CONFIG_DIR env var, consistent with transcript.py's
# _projects_base() function.
_config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
SDK_PROJECTS_DIR = os.path.realpath(os.path.join(_config_dir, "projects"))
# Compiled UUID pattern for validating conversation directory names.
# Kept as a module-level constant so the security-relevant pattern is easy
# to audit in one place and avoids recompilation on every call.
_UUID_RE = re.compile(r"^[0-9a-f]{8}(?:-[0-9a-f]{4}){3}-[0-9a-f]{12}$", re.IGNORECASE)
# Allowed base directory for the Read tool.
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# Encoded project-directory name for the current session (e.g.
# "-private-tmp-copilot-<uuid>"). Set by set_execution_context() so path
@@ -44,20 +35,11 @@ _current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
def encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does.
The Claude CLI encodes the absolute cwd as a directory name by replacing
every non-alphanumeric character with ``-``. For example
``/tmp/copilot-abc`` becomes ``-tmp-copilot-abc``.
"""
def _encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does."""
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
# Keep the private alias for internal callers (backwards compat).
_encode_cwd_for_cli = encode_cwd_for_cli
def set_execution_context(
user_id: str | None,
session: ChatSession,
@@ -118,9 +100,7 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
Allowed:
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
- Files under ``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/...``.
The SDK nests tool-results under a conversation UUID directory;
the UUID segment is validated with ``_UUID_RE``.
- Files under ``~/.claude/projects/<encoded-cwd>/tool-results/`` (SDK tool-results)
"""
if not path:
return False
@@ -139,22 +119,10 @@ def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
encoded = _current_project_dir.get("")
if encoded:
project_dir = os.path.realpath(os.path.join(SDK_PROJECTS_DIR, encoded))
# Defence-in-depth: ensure project_dir didn't escape the base.
if not project_dir.startswith(SDK_PROJECTS_DIR + os.sep):
return False
# Only allow: <encoded-cwd>/<uuid>/tool-results/<file>
# The SDK always creates a conversation UUID directory between
# the project dir and tool-results/.
if resolved.startswith(project_dir + os.sep):
relative = resolved[len(project_dir) + 1 :]
parts = relative.split(os.sep)
# Require exactly: [<uuid>, "tool-results", <file>, ...]
if (
len(parts) >= 3
and _UUID_RE.match(parts[0])
and parts[1] == "tool-results"
):
return True
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
if resolved == tool_results_dir or resolved.startswith(
tool_results_dir + os.sep
):
return True
return False

View File

@@ -9,7 +9,7 @@ from unittest.mock import MagicMock
import pytest
from backend.copilot.context import (
SDK_PROJECTS_DIR,
_SDK_PROJECTS_DIR,
_current_project_dir,
get_current_sandbox,
get_execution_context,
@@ -104,13 +104,11 @@ def test_is_allowed_local_path_no_sdk_cwd_no_project_dir():
assert not is_allowed_local_path("/tmp/some-file.txt", sdk_cwd=None)
def test_is_allowed_local_path_tool_results_with_uuid():
"""Files under <encoded-cwd>/<uuid>/tool-results/ are allowed."""
def test_is_allowed_local_path_tool_results_dir():
"""Files under the tool-results directory for the current project are allowed."""
encoded = "test-encoded-dir"
conv_uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
path = os.path.join(
SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-results", "output.txt"
)
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
path = os.path.join(tool_results_dir, "output.txt")
_current_project_dir.set(encoded)
try:
@@ -119,22 +117,10 @@ def test_is_allowed_local_path_tool_results_with_uuid():
_current_project_dir.set("")
def test_is_allowed_local_path_tool_results_without_uuid_rejected():
"""Direct <encoded-cwd>/tool-results/ (no UUID) is rejected."""
encoded = "test-encoded-dir"
path = os.path.join(SDK_PROJECTS_DIR, encoded, "tool-results", "output.txt")
_current_project_dir.set(encoded)
try:
assert not is_allowed_local_path(path, sdk_cwd=None)
finally:
_current_project_dir.set("")
def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
"""A path adjacent to tool-results/ but not inside it is rejected."""
encoded = "test-encoded-dir"
sibling_path = os.path.join(SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
sibling_path = os.path.join(_SDK_PROJECTS_DIR, encoded, "other-dir", "file.txt")
_current_project_dir.set(encoded)
try:
@@ -143,21 +129,6 @@ def test_is_allowed_local_path_sibling_of_tool_results_is_rejected():
_current_project_dir.set("")
def test_is_allowed_local_path_valid_uuid_wrong_segment_name_rejected():
"""A valid UUID dir but non-'tool-results' second segment is rejected."""
encoded = "test-encoded-dir"
uuid_str = "12345678-1234-5678-9abc-def012345678"
path = os.path.join(
SDK_PROJECTS_DIR, encoded, uuid_str, "not-tool-results", "output.txt"
)
_current_project_dir.set(encoded)
try:
assert not is_allowed_local_path(path, sdk_cwd=None)
finally:
_current_project_dir.set("")
# ---------------------------------------------------------------------------
# resolve_sandbox_path
# ---------------------------------------------------------------------------

View File

@@ -152,13 +152,6 @@ def _build_storage_supplement(
### File persistence
Important files (code, configs, outputs) should be saved to workspace to ensure they persist.
### SDK tool-result files
When tool outputs are large, the SDK truncates them and saves the full output to
a local file under `~/.claude/projects/.../tool-results/`. To read these files,
always use `read_file` or `Read` (NOT `read_workspace_file`).
`read_workspace_file` reads from cloud workspace storage, where SDK
tool-results are NOT stored.
{_SHARED_TOOL_NOTES}"""

View File

@@ -43,6 +43,7 @@ class ResponseType(str, Enum):
ERROR = "error"
USAGE = "usage"
HEARTBEAT = "heartbeat"
STATUS = "status"
class StreamBaseResponse(BaseModel):
@@ -232,3 +233,26 @@ class StreamHeartbeat(StreamBaseResponse):
def to_sse(self) -> str:
"""Convert to SSE comment format to keep connection alive."""
return ": heartbeat\n\n"
class StreamStatus(StreamBaseResponse):
"""Transient status notification shown to the user during long operations.
Used to provide feedback when the backend performs behind-the-scenes work
(e.g., compacting conversation context on a retry) that would otherwise
leave the user staring at an unexplained pause.
"""
type: ResponseType = ResponseType.STATUS
message: str = Field(..., description="Human-readable status message")
def to_sse(self) -> str:
"""Encode as an SSE comment so the AI SDK stream parser ignores it.
The frontend AI SDK validates every ``data:`` line against a strict
Zod union of known chunk types. ``"status"`` is not in that union,
so sending it as ``data:`` would cause a schema-validation error that
breaks the entire stream. Using an SSE comment (``:``) keeps the
connection alive and is silently discarded by ``EventSource`` parsers.
"""
return f": status {self.message}\n\n"

View File

@@ -12,6 +12,7 @@ import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from typing import Any
from ..constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
from ..model import ChatMessage, ChatSession
@@ -119,14 +120,12 @@ def filter_compaction_messages(
filtered: list[ChatMessage] = []
for msg in messages:
if msg.role == "assistant" and msg.tool_calls:
real_calls: list[dict[str, Any]] = []
for tc in msg.tool_calls:
if tc.get("function", {}).get("name") == COMPACTION_TOOL_NAME:
compaction_ids.add(tc.get("id", ""))
real_calls = [
tc
for tc in msg.tool_calls
if tc.get("function", {}).get("name") != COMPACTION_TOOL_NAME
]
else:
real_calls.append(tc)
if not real_calls and not msg.content:
continue
if msg.role == "tool" and msg.tool_call_id in compaction_ids:
@@ -222,6 +221,7 @@ class CompactionTracker:
def reset_for_query(self) -> None:
"""Reset per-query state before a new SDK query."""
self._compact_start.clear()
self._done = False
self._start_emitted = False
self._tool_call_id = ""

View File

@@ -0,0 +1,41 @@
"""Shared test fixtures for copilot SDK tests."""
from __future__ import annotations
from uuid import uuid4
from backend.util import json
def build_test_transcript(pairs: list[tuple[str, str]]) -> str:
"""Build a minimal valid JSONL transcript from (role, content) pairs.
Use this helper in any copilot SDK test that needs a well-formed
transcript without hitting the real storage layer.
"""
lines: list[str] = []
last_uuid: str | None = None
for role, content in pairs:
uid = str(uuid4())
entry_type = "assistant" if role == "assistant" else "user"
msg: dict = {"role": role, "content": content}
if role == "assistant":
msg.update(
{
"model": "",
"id": f"msg_{uid[:8]}",
"type": "message",
"content": [{"type": "text", "text": content}],
"stop_reason": "end_turn",
"stop_sequence": None,
}
)
entry = {
"type": entry_type,
"uuid": uid,
"parentUuid": last_uuid,
"message": msg,
}
lines.append(json.dumps(entry, separators=(",", ":")))
last_uuid = uid
return "\n".join(lines) + "\n"

View File

@@ -26,41 +26,6 @@ from backend.copilot.context import (
logger = logging.getLogger(__name__)
async def _check_sandbox_symlink_escape(
sandbox: Any,
parent: str,
) -> str | None:
"""Resolve the canonical parent path inside the sandbox to detect symlink escapes.
``normpath`` (used by ``resolve_sandbox_path``) only normalises the string;
``readlink -f`` follows actual symlinks on the sandbox filesystem.
Returns the canonical parent path, or ``None`` if the path escapes
``E2B_WORKDIR``.
Note: There is an inherent TOCTOU window between this check and the
subsequent ``sandbox.files.write()``. A symlink could theoretically be
replaced between the two operations. This is acceptable in the E2B
sandbox model since the sandbox is single-user and ephemeral.
"""
canonical_res = await sandbox.commands.run(
f"readlink -f {shlex.quote(parent or E2B_WORKDIR)}",
cwd=E2B_WORKDIR,
timeout=5,
)
canonical_parent = (canonical_res.stdout or "").strip()
if (
canonical_res.exit_code != 0
or not canonical_parent
or (
canonical_parent != E2B_WORKDIR
and not canonical_parent.startswith(E2B_WORKDIR + "/")
)
):
return None
return canonical_parent
def _get_sandbox():
return get_current_sandbox()
@@ -141,10 +106,6 @@ async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
parent = os.path.dirname(remote)
if parent and parent != E2B_WORKDIR:
await sandbox.files.make_dir(parent)
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
if canonical_parent is None:
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
remote = os.path.join(canonical_parent, os.path.basename(remote))
await sandbox.files.write(remote, content)
except Exception as exc:
return _mcp(f"Failed to write {remote}: {exc}", error=True)
@@ -169,12 +130,6 @@ async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
return result
sandbox, remote = result
parent = os.path.dirname(remote)
canonical_parent = await _check_sandbox_symlink_escape(sandbox, parent)
if canonical_parent is None:
return _mcp(f"Path must be within {E2B_WORKDIR}: {parent}", error=True)
remote = os.path.join(canonical_parent, os.path.basename(remote))
try:
raw: bytes = await sandbox.files.read(remote, format="bytes")
content = raw.decode("utf-8", errors="replace")

View File

@@ -4,19 +4,15 @@ Pure unit tests with no external dependencies (no E2B, no sandbox).
"""
import os
import shutil
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from backend.copilot.context import E2B_WORKDIR, SDK_PROJECTS_DIR, _current_project_dir
from backend.copilot.context import _current_project_dir
from .e2b_file_tools import _read_local, resolve_sandbox_path
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
from .e2b_file_tools import (
_check_sandbox_symlink_escape,
_read_local,
resolve_sandbox_path,
)
# ---------------------------------------------------------------------------
# resolve_sandbox_path — sandbox path normalisation & boundary enforcement
@@ -25,48 +21,46 @@ from .e2b_file_tools import (
class TestResolveSandboxPath:
def test_relative_path_resolved(self):
assert resolve_sandbox_path("src/main.py") == f"{E2B_WORKDIR}/src/main.py"
assert resolve_sandbox_path("src/main.py") == "/home/user/src/main.py"
def test_absolute_within_sandbox(self):
assert (
resolve_sandbox_path(f"{E2B_WORKDIR}/file.txt") == f"{E2B_WORKDIR}/file.txt"
)
assert resolve_sandbox_path("/home/user/file.txt") == "/home/user/file.txt"
def test_workdir_itself(self):
assert resolve_sandbox_path(E2B_WORKDIR) == E2B_WORKDIR
assert resolve_sandbox_path("/home/user") == "/home/user"
def test_relative_dotslash(self):
assert resolve_sandbox_path("./README.md") == f"{E2B_WORKDIR}/README.md"
assert resolve_sandbox_path("./README.md") == "/home/user/README.md"
def test_traversal_blocked(self):
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
with pytest.raises(ValueError, match="must be within /home/user"):
resolve_sandbox_path("../../etc/passwd")
def test_absolute_traversal_blocked(self):
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
resolve_sandbox_path(f"{E2B_WORKDIR}/../../etc/passwd")
with pytest.raises(ValueError, match="must be within /home/user"):
resolve_sandbox_path("/home/user/../../etc/passwd")
def test_absolute_outside_sandbox_blocked(self):
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
with pytest.raises(ValueError, match="must be within /home/user"):
resolve_sandbox_path("/etc/passwd")
def test_root_blocked(self):
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
with pytest.raises(ValueError, match="must be within /home/user"):
resolve_sandbox_path("/")
def test_home_other_user_blocked(self):
with pytest.raises(ValueError, match=f"must be within {E2B_WORKDIR}"):
with pytest.raises(ValueError, match="must be within /home/user"):
resolve_sandbox_path("/home/other/file.txt")
def test_deep_nested_allowed(self):
assert resolve_sandbox_path("a/b/c/d/e.txt") == f"{E2B_WORKDIR}/a/b/c/d/e.txt"
assert resolve_sandbox_path("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
def test_trailing_slash_normalised(self):
assert resolve_sandbox_path("src/") == f"{E2B_WORKDIR}/src"
assert resolve_sandbox_path("src/") == "/home/user/src"
def test_double_dots_within_sandbox_ok(self):
"""Path that resolves back within E2B_WORKDIR is allowed."""
assert resolve_sandbox_path("a/b/../c.txt") == f"{E2B_WORKDIR}/a/c.txt"
"""Path that resolves back within /home/user is allowed."""
assert resolve_sandbox_path("a/b/../c.txt") == "/home/user/a/c.txt"
# ---------------------------------------------------------------------------
@@ -79,13 +73,9 @@ class TestResolveSandboxPath:
class TestReadLocal:
_CONV_UUID = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
def _make_tool_results_file(self, encoded: str, filename: str, content: str) -> str:
"""Create a tool-results file under <encoded>/<uuid>/tool-results/."""
tool_results_dir = os.path.join(
SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-results"
)
"""Create a tool-results file and return its path."""
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
os.makedirs(tool_results_dir, exist_ok=True)
filepath = os.path.join(tool_results_dir, filename)
with open(filepath, "w") as f:
@@ -117,9 +107,7 @@ class TestReadLocal:
def test_read_nonexistent_tool_results(self):
"""A tool-results path that doesn't exist returns FileNotFoundError."""
encoded = "-tmp-copilot-e2b-test-nofile"
tool_results_dir = os.path.join(
SDK_PROJECTS_DIR, encoded, self._CONV_UUID, "tool-results"
)
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
os.makedirs(tool_results_dir, exist_ok=True)
filepath = os.path.join(tool_results_dir, "nonexistent.txt")
token = _current_project_dir.set(encoded)
@@ -129,7 +117,7 @@ class TestReadLocal:
assert "not found" in result["content"][0]["text"].lower()
finally:
_current_project_dir.reset(token)
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
os.rmdir(tool_results_dir)
def test_read_traversal_path_blocked(self):
"""A traversal attempt that escapes allowed directories is blocked."""
@@ -164,66 +152,3 @@ class TestReadLocal:
"""Without _current_project_dir set, all paths are blocked."""
result = _read_local("/tmp/anything.txt", offset=0, limit=10)
assert result["isError"] is True
# ---------------------------------------------------------------------------
# _check_sandbox_symlink_escape — symlink escape detection
# ---------------------------------------------------------------------------
def _make_sandbox(stdout: str, exit_code: int = 0) -> SimpleNamespace:
"""Build a minimal sandbox mock whose commands.run returns a fixed result."""
run_result = SimpleNamespace(stdout=stdout, exit_code=exit_code)
commands = SimpleNamespace(run=AsyncMock(return_value=run_result))
return SimpleNamespace(commands=commands)
class TestCheckSandboxSymlinkEscape:
@pytest.mark.asyncio
async def test_canonical_path_within_workdir_returns_path(self):
"""When readlink -f resolves to a path inside E2B_WORKDIR, returns it."""
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}/src\n", exit_code=0)
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
assert result == f"{E2B_WORKDIR}/src"
@pytest.mark.asyncio
async def test_workdir_itself_returns_workdir(self):
"""When readlink -f resolves to E2B_WORKDIR exactly, returns E2B_WORKDIR."""
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}\n", exit_code=0)
result = await _check_sandbox_symlink_escape(sandbox, E2B_WORKDIR)
assert result == E2B_WORKDIR
@pytest.mark.asyncio
async def test_symlink_escape_returns_none(self):
"""When readlink -f resolves outside E2B_WORKDIR (symlink escape), returns None."""
sandbox = _make_sandbox(stdout="/etc\n", exit_code=0)
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/evil")
assert result is None
@pytest.mark.asyncio
async def test_nonzero_exit_code_returns_none(self):
"""A non-zero exit code from readlink -f returns None."""
sandbox = _make_sandbox(stdout="", exit_code=1)
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
assert result is None
@pytest.mark.asyncio
async def test_empty_stdout_returns_none(self):
"""Empty stdout from readlink (e.g. path doesn't exist yet) returns None."""
sandbox = _make_sandbox(stdout="", exit_code=0)
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/src")
assert result is None
@pytest.mark.asyncio
async def test_prefix_collision_returns_none(self):
"""A path prefixed with E2B_WORKDIR but not within it is rejected."""
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}-evil\n", exit_code=0)
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}-evil")
assert result is None
@pytest.mark.asyncio
async def test_deeply_nested_path_within_workdir(self):
"""Deep nested paths inside E2B_WORKDIR are allowed."""
sandbox = _make_sandbox(stdout=f"{E2B_WORKDIR}/a/b/c/d\n", exit_code=0)
result = await _check_sandbox_symlink_escape(sandbox, f"{E2B_WORKDIR}/a/b/c/d")
assert result == f"{E2B_WORKDIR}/a/b/c/d"

View File

@@ -0,0 +1,552 @@
"""Tests for retry logic and transcript compaction helpers."""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
from uuid import uuid4
import pytest
from backend.util import json
from .conftest import build_test_transcript as _build_transcript
from .service import _is_prompt_too_long
from .transcript import (
_flatten_assistant_content,
_flatten_tool_result_content,
_messages_to_transcript,
_transcript_to_messages,
compact_transcript,
validate_transcript,
)
# ---------------------------------------------------------------------------
# _flatten_assistant_content
# ---------------------------------------------------------------------------
class TestFlattenAssistantContent:
def test_text_blocks(self):
blocks = [
{"type": "text", "text": "Hello"},
{"type": "text", "text": "World"},
]
assert _flatten_assistant_content(blocks) == "Hello\nWorld"
def test_tool_use_blocks(self):
blocks = [{"type": "tool_use", "name": "read_file", "input": {}}]
assert _flatten_assistant_content(blocks) == "[tool_use: read_file]"
def test_mixed_blocks(self):
blocks = [
{"type": "text", "text": "Let me read that."},
{"type": "tool_use", "name": "Read", "input": {"path": "/foo"}},
]
result = _flatten_assistant_content(blocks)
assert "Let me read that." in result
assert "[tool_use: Read]" in result
def test_raw_strings(self):
assert _flatten_assistant_content(["hello", "world"]) == "hello\nworld"
def test_unknown_block_type_preserved_as_placeholder(self):
blocks = [
{"type": "text", "text": "See this image:"},
{"type": "image", "source": {"type": "base64", "data": "..."}},
]
result = _flatten_assistant_content(blocks)
assert "See this image:" in result
assert "[__image__]" in result
def test_empty(self):
assert _flatten_assistant_content([]) == ""
# ---------------------------------------------------------------------------
# _flatten_tool_result_content
# ---------------------------------------------------------------------------
class TestFlattenToolResultContent:
def test_tool_result_with_text(self):
blocks = [
{
"type": "tool_result",
"tool_use_id": "123",
"content": [{"type": "text", "text": "file contents here"}],
}
]
assert _flatten_tool_result_content(blocks) == "file contents here"
def test_tool_result_with_string_content(self):
blocks = [{"type": "tool_result", "tool_use_id": "123", "content": "ok"}]
assert _flatten_tool_result_content(blocks) == "ok"
def test_text_block(self):
blocks = [{"type": "text", "text": "plain text"}]
assert _flatten_tool_result_content(blocks) == "plain text"
def test_raw_string(self):
assert _flatten_tool_result_content(["raw"]) == "raw"
def test_tool_result_with_none_content(self):
"""tool_result with content=None should produce empty string."""
blocks = [{"type": "tool_result", "tool_use_id": "x", "content": None}]
assert _flatten_tool_result_content(blocks) == ""
def test_tool_result_with_empty_list_content(self):
"""tool_result with content=[] should produce empty string."""
blocks = [{"type": "tool_result", "tool_use_id": "x", "content": []}]
assert _flatten_tool_result_content(blocks) == ""
def test_empty(self):
assert _flatten_tool_result_content([]) == ""
def test_nested_dict_without_text(self):
"""Dict blocks without text key use json.dumps fallback."""
blocks = [
{
"type": "tool_result",
"tool_use_id": "x",
"content": [{"type": "image", "source": "data:..."}],
}
]
result = _flatten_tool_result_content(blocks)
assert "image" in result # json.dumps fallback
def test_unknown_block_type_preserved_as_placeholder(self):
blocks = [{"type": "image", "source": {"type": "base64", "data": "..."}}]
result = _flatten_tool_result_content(blocks)
assert "[__image__]" in result
# ---------------------------------------------------------------------------
# _transcript_to_messages
# ---------------------------------------------------------------------------
def _make_entry(entry_type: str, role: str, content: str | list, **kwargs) -> str:
"""Build a JSONL line for testing."""
uid = str(uuid4())
msg: dict = {"role": role, "content": content}
msg.update(kwargs)
entry = {
"type": entry_type,
"uuid": uid,
"parentUuid": None,
"message": msg,
}
return json.dumps(entry, separators=(",", ":"))
class TestTranscriptToMessages:
def test_basic_roundtrip(self):
lines = [
_make_entry("user", "user", "Hello"),
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
]
content = "\n".join(lines) + "\n"
messages = _transcript_to_messages(content)
assert len(messages) == 2
assert messages[0] == {"role": "user", "content": "Hello"}
assert messages[1] == {"role": "assistant", "content": "Hi"}
def test_skips_strippable_types(self):
"""Progress and metadata entries are excluded."""
lines = [
_make_entry("user", "user", "Hello"),
json.dumps(
{
"type": "progress",
"uuid": str(uuid4()),
"parentUuid": None,
"message": {"role": "assistant", "content": "..."},
}
),
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
]
content = "\n".join(lines) + "\n"
messages = _transcript_to_messages(content)
assert len(messages) == 2
def test_empty_content(self):
assert _transcript_to_messages("") == []
def test_tool_result_content(self):
"""User entries with tool_result content blocks are flattened."""
lines = [
_make_entry(
"user",
"user",
[
{
"type": "tool_result",
"tool_use_id": "123",
"content": "tool output",
}
],
),
]
content = "\n".join(lines) + "\n"
messages = _transcript_to_messages(content)
assert len(messages) == 1
assert messages[0]["content"] == "tool output"
def test_malformed_json_lines_skipped(self):
"""Malformed JSON lines in transcript are silently skipped."""
lines = [
_make_entry("user", "user", "Hello"),
"this is not valid json",
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
]
content = "\n".join(lines) + "\n"
messages = _transcript_to_messages(content)
assert len(messages) == 2
def test_empty_lines_skipped(self):
"""Empty lines and whitespace-only lines are skipped."""
lines = [
_make_entry("user", "user", "Hello"),
"",
" ",
_make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
]
content = "\n".join(lines) + "\n"
messages = _transcript_to_messages(content)
assert len(messages) == 2
def test_unicode_content_preserved(self):
"""Unicode characters survive transcript roundtrip."""
lines = [
_make_entry("user", "user", "Hello 你好 🌍"),
_make_entry(
"assistant",
"assistant",
[{"type": "text", "text": "Bonjour 日本語 émojis 🎉"}],
),
]
content = "\n".join(lines) + "\n"
messages = _transcript_to_messages(content)
assert messages[0]["content"] == "Hello 你好 🌍"
assert messages[1]["content"] == "Bonjour 日本語 émojis 🎉"
def test_entry_without_role_skipped(self):
"""Entries with missing role in message are skipped."""
entry_no_role = json.dumps(
{
"type": "user",
"uuid": str(uuid4()),
"parentUuid": None,
"message": {"content": "no role here"},
}
)
lines = [
entry_no_role,
_make_entry("user", "user", "Hello"),
]
content = "\n".join(lines) + "\n"
messages = _transcript_to_messages(content)
assert len(messages) == 1
assert messages[0]["content"] == "Hello"
def test_tool_use_and_result_pairs(self):
"""Tool use + tool result pairs are properly flattened."""
lines = [
_make_entry(
"assistant",
"assistant",
[
{"type": "text", "text": "Let me check."},
{"type": "tool_use", "name": "read_file", "input": {"path": "/x"}},
],
),
_make_entry(
"user",
"user",
[
{
"type": "tool_result",
"tool_use_id": "abc",
"content": [{"type": "text", "text": "file contents"}],
}
],
),
]
content = "\n".join(lines) + "\n"
messages = _transcript_to_messages(content)
assert len(messages) == 2
assert "Let me check." in messages[0]["content"]
assert "[tool_use: read_file]" in messages[0]["content"]
assert messages[1]["content"] == "file contents"
# ---------------------------------------------------------------------------
# _messages_to_transcript
# ---------------------------------------------------------------------------
class TestMessagesToTranscript:
def test_produces_valid_jsonl(self):
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there"},
]
result = _messages_to_transcript(messages)
lines = result.strip().split("\n")
assert len(lines) == 2
for line in lines:
parsed = json.loads(line)
assert "type" in parsed
assert "uuid" in parsed
assert "message" in parsed
def test_assistant_has_proper_structure(self):
messages = [{"role": "assistant", "content": "Hello"}]
result = _messages_to_transcript(messages)
entry = json.loads(result.strip())
assert entry["type"] == "assistant"
msg = entry["message"]
assert msg["role"] == "assistant"
assert msg["type"] == "message"
assert msg["stop_reason"] == "end_turn"
assert isinstance(msg["content"], list)
assert msg["content"][0]["type"] == "text"
def test_user_has_plain_content(self):
messages = [{"role": "user", "content": "Hi"}]
result = _messages_to_transcript(messages)
entry = json.loads(result.strip())
assert entry["type"] == "user"
assert entry["message"]["content"] == "Hi"
def test_parent_uuid_chain(self):
messages = [
{"role": "user", "content": "A"},
{"role": "assistant", "content": "B"},
{"role": "user", "content": "C"},
]
result = _messages_to_transcript(messages)
lines = result.strip().split("\n")
entries = [json.loads(line) for line in lines]
assert entries[0]["parentUuid"] == ""
assert entries[1]["parentUuid"] == entries[0]["uuid"]
assert entries[2]["parentUuid"] == entries[1]["uuid"]
def test_empty_messages(self):
assert _messages_to_transcript([]) == ""
def test_output_is_valid_transcript(self):
"""Output should pass validate_transcript if it has assistant entries."""
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi"},
]
result = _messages_to_transcript(messages)
assert validate_transcript(result)
def test_roundtrip_to_messages(self):
"""Messages → transcript → messages preserves structure."""
original = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there"},
{"role": "user", "content": "How are you?"},
]
transcript = _messages_to_transcript(original)
restored = _transcript_to_messages(transcript)
assert len(restored) == len(original)
for orig, rest in zip(original, restored):
assert orig["role"] == rest["role"]
assert orig["content"] == rest["content"]
# ---------------------------------------------------------------------------
# compact_transcript
# ---------------------------------------------------------------------------
class TestCompactTranscript:
@pytest.mark.asyncio
async def test_too_few_messages_returns_none(self):
"""compact_transcript returns None when transcript has < 2 messages."""
transcript = _build_transcript([("user", "Hello")])
with patch(
"backend.copilot.config.ChatConfig",
return_value=type(
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
)(),
):
result = await compact_transcript(transcript, model="test-model")
assert result is None
@pytest.mark.asyncio
async def test_returns_none_when_not_compacted(self):
"""When compress_context says no compaction needed, returns None.
The compressor couldn't reduce it, so retrying with the same
content would fail identically."""
transcript = _build_transcript(
[
("user", "Hello"),
("assistant", "Hi there"),
]
)
mock_result = type(
"CompressResult",
(),
{
"was_compacted": False,
"messages": [],
"original_token_count": 100,
"token_count": 100,
"messages_summarized": 0,
"messages_dropped": 0,
},
)()
with (
patch(
"backend.copilot.config.ChatConfig",
return_value=type(
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
)(),
),
patch(
"backend.copilot.sdk.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
),
):
result = await compact_transcript(transcript, model="test-model")
assert result is None
@pytest.mark.asyncio
async def test_returns_compacted_transcript(self):
"""When compaction succeeds, returns a valid compacted transcript."""
transcript = _build_transcript(
[
("user", "Hello"),
("assistant", "Hi"),
("user", "More"),
("assistant", "Details"),
]
)
compacted_msgs = [
{"role": "user", "content": "[summary]"},
{"role": "assistant", "content": "Summarized response"},
]
mock_result = type(
"CompressResult",
(),
{
"was_compacted": True,
"messages": compacted_msgs,
"original_token_count": 500,
"token_count": 100,
"messages_summarized": 2,
"messages_dropped": 0,
},
)()
with (
patch(
"backend.copilot.config.ChatConfig",
return_value=type(
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
)(),
),
patch(
"backend.copilot.sdk.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
),
):
result = await compact_transcript(transcript, model="test-model")
assert result is not None
assert validate_transcript(result)
msgs = _transcript_to_messages(result)
assert len(msgs) == 2
assert msgs[1]["content"] == "Summarized response"
@pytest.mark.asyncio
async def test_returns_none_on_compression_failure(self):
"""When _run_compression raises, returns None."""
transcript = _build_transcript(
[
("user", "Hello"),
("assistant", "Hi"),
]
)
with (
patch(
"backend.copilot.config.ChatConfig",
return_value=type(
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
)(),
),
patch(
"backend.copilot.sdk.transcript._run_compression",
new_callable=AsyncMock,
side_effect=RuntimeError("LLM unavailable"),
),
):
result = await compact_transcript(transcript, model="test-model")
assert result is None
# ---------------------------------------------------------------------------
# _is_prompt_too_long
# ---------------------------------------------------------------------------
class TestIsPromptTooLong:
"""Unit tests for _is_prompt_too_long pattern matching."""
def test_prompt_is_too_long(self):
err = RuntimeError("prompt is too long for model context")
assert _is_prompt_too_long(err) is True
def test_request_too_large(self):
err = Exception("request too large: 250000 tokens")
assert _is_prompt_too_long(err) is True
def test_maximum_context_length(self):
err = ValueError("maximum context length exceeded")
assert _is_prompt_too_long(err) is True
def test_context_length_exceeded(self):
err = Exception("context_length_exceeded")
assert _is_prompt_too_long(err) is True
def test_input_tokens_exceed(self):
err = Exception("input tokens exceed the max_tokens limit")
assert _is_prompt_too_long(err) is True
def test_input_is_too_long(self):
err = Exception("input is too long for the model")
assert _is_prompt_too_long(err) is True
def test_content_length_exceeds(self):
err = Exception("content length exceeds maximum")
assert _is_prompt_too_long(err) is True
def test_unrelated_error_returns_false(self):
err = RuntimeError("network timeout")
assert _is_prompt_too_long(err) is False
def test_auth_error_returns_false(self):
err = Exception("authentication failed: invalid API key")
assert _is_prompt_too_long(err) is False
def test_chained_exception_detected(self):
"""Prompt-too-long error wrapped in another exception is detected."""
inner = RuntimeError("prompt is too long")
outer = Exception("SDK error")
outer.__cause__ = inner
assert _is_prompt_too_long(outer) is True
def test_case_insensitive(self):
err = Exception("PROMPT IS TOO LONG")
assert _is_prompt_too_long(err) is True
def test_old_max_tokens_exceeded_not_matched(self):
"""The old broad 'max_tokens_exceeded' pattern was removed.
Only 'input tokens exceed' should match now."""
err = Exception("max_tokens_exceeded")
assert _is_prompt_too_long(err) is False

View File

@@ -226,7 +226,7 @@ class SDKResponseAdapter:
responses.append(StreamFinish())
else:
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
logger.debug("Unhandled SDK message type: %s", type(sdk_message).__name__)
return responses

File diff suppressed because it is too large Load Diff

View File

@@ -42,7 +42,7 @@ def _validate_workspace_path(
Delegates to :func:`is_allowed_local_path` which permits:
- The SDK working directory (``/tmp/copilot-<session>/``)
- The current session's tool-results directory
(``~/.claude/projects/<encoded-cwd>/<uuid>/tool-results/``)
(``~/.claude/projects/<encoded-cwd>/tool-results/``)
"""
path = tool_input.get("file_path") or tool_input.get("path") or ""
if not path:
@@ -52,7 +52,7 @@ def _validate_workspace_path(
if is_allowed_local_path(path, sdk_cwd):
return {}
logger.warning(f"Blocked {tool_name} outside workspace: {path}")
logger.warning("Blocked %s outside workspace: %s", tool_name, path)
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
return _deny(
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
@@ -71,7 +71,7 @@ def _validate_tool_access(
"""
# Block forbidden tools
if tool_name in BLOCKED_TOOLS:
logger.warning(f"Blocked tool access attempt: {tool_name}")
logger.warning("Blocked tool access attempt: %s", tool_name)
return _deny(
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
"This is enforced by the platform and cannot be bypassed. "
@@ -89,7 +89,9 @@ def _validate_tool_access(
for pattern in DANGEROUS_PATTERNS:
if re.search(pattern, input_str, re.IGNORECASE):
logger.warning(
f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
"Blocked dangerous pattern in tool input: %s in %s",
pattern,
tool_name,
)
return _deny(
"[SECURITY] Input contains a blocked pattern. "
@@ -111,7 +113,9 @@ def _validate_user_isolation(
# the tool itself via _validate_ephemeral_path.
path = tool_input.get("path", "") or tool_input.get("file_path", "")
if path and ".." in path:
logger.warning(f"Blocked path traversal attempt: {path} by user {user_id}")
logger.warning(
"Blocked path traversal attempt: %s by user %s", path, user_id
)
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
@@ -170,7 +174,7 @@ def create_security_hooks(
# Block background task execution first — denied calls
# should not consume a subtask slot.
if tool_input.get("run_in_background"):
logger.info(f"[SDK] Blocked background Task, user={user_id}")
logger.info("[SDK] Blocked background Task, user=%s", user_id)
return cast(
SyncHookJSONOutput,
_deny(
@@ -181,7 +185,9 @@ def create_security_hooks(
)
if len(task_tool_use_ids) >= max_subtasks:
logger.warning(
f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
"[SDK] Task limit reached (%d), user=%s",
max_subtasks,
user_id,
)
return cast(
SyncHookJSONOutput,
@@ -212,7 +218,7 @@ def create_security_hooks(
if tool_name == "Task" and tool_use_id is not None:
task_tool_use_ids.add(tool_use_id)
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
logger.debug("[SDK] Tool start: %s, user=%s", tool_name, user_id)
return cast(SyncHookJSONOutput, {})
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:
@@ -282,8 +288,11 @@ def create_security_hooks(
tool_name = cast(str, input_data.get("tool_name", ""))
error = input_data.get("error", "Unknown error")
logger.warning(
f"[SDK] Tool failed: {tool_name}, error={error}, "
f"user={user_id}, tool_use_id={tool_use_id}"
"[SDK] Tool failed: %s, error=%s, user=%s, tool_use_id=%s",
tool_name,
str(error).replace("\n", "").replace("\r", ""),
user_id,
tool_use_id,
)
_release_task_slot(tool_name, tool_use_id)
@@ -301,20 +310,19 @@ def create_security_hooks(
This hook provides visibility into when compaction happens.
"""
_ = context, tool_use_id
trigger = input_data.get("trigger", "auto")
# Sanitize untrusted input: strip control chars for logging AND
# for the value passed downstream. read_compacted_entries()
# validates against _projects_base() as defence-in-depth, but
# sanitizing here prevents log injection and rejects obviously
# malformed paths early.
# Sanitize untrusted input before logging to prevent log injection
trigger = (
str(input_data.get("trigger", "auto"))
.replace("\n", "")
.replace("\r", "")
)
transcript_path = (
str(input_data.get("transcript_path", ""))
.replace("\n", "")
.replace("\r", "")
)
logger.info(
"[SDK] Context compaction triggered: %s, user=%s, "
"transcript_path=%s",
"[SDK] Context compaction triggered: %s, user=%s, transcript_path=%s",
trigger,
user_id,
transcript_path,

View File

@@ -122,7 +122,7 @@ def test_read_no_cwd_denies_absolute():
def test_read_tool_results_allowed():
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/a1b2c3d4-e5f6-7890-abcd-ef1234567890/tool-results/12345.txt"
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
# is_allowed_local_path requires the session's encoded cwd to be set
token = _current_project_dir.set("-tmp-copilot-abc123")
try:

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,283 @@
"""Unit tests for extracted service helpers.
Covers ``_is_prompt_too_long``, ``_reduce_context``, ``_iter_sdk_messages``,
and the ``ReducedContext`` named tuple.
"""
from __future__ import annotations
import asyncio
from collections.abc import AsyncGenerator
from unittest.mock import AsyncMock, patch
import pytest
from .conftest import build_test_transcript as _build_transcript
from .service import (
ReducedContext,
_is_prompt_too_long,
_iter_sdk_messages,
_reduce_context,
)
# ---------------------------------------------------------------------------
# _is_prompt_too_long
# ---------------------------------------------------------------------------
class TestIsPromptTooLong:
def test_direct_match(self) -> None:
assert _is_prompt_too_long(Exception("prompt is too long")) is True
def test_case_insensitive(self) -> None:
assert _is_prompt_too_long(Exception("PROMPT IS TOO LONG")) is True
def test_no_match(self) -> None:
assert _is_prompt_too_long(Exception("network timeout")) is False
def test_request_too_large(self) -> None:
assert _is_prompt_too_long(Exception("request too large for model")) is True
def test_context_length_exceeded(self) -> None:
assert _is_prompt_too_long(Exception("context_length_exceeded")) is True
def test_max_tokens_exceeded_not_matched(self) -> None:
"""'max_tokens_exceeded' is intentionally excluded (too broad)."""
assert _is_prompt_too_long(Exception("max_tokens_exceeded")) is False
def test_max_tokens_config_error_no_match(self) -> None:
"""'max_tokens must be at least 1' should NOT match."""
assert _is_prompt_too_long(Exception("max_tokens must be at least 1")) is False
def test_chained_cause(self) -> None:
inner = Exception("prompt is too long")
outer = RuntimeError("SDK error")
outer.__cause__ = inner
assert _is_prompt_too_long(outer) is True
def test_chained_context(self) -> None:
inner = Exception("request too large")
outer = RuntimeError("wrapped")
outer.__context__ = inner
assert _is_prompt_too_long(outer) is True
def test_deep_chain(self) -> None:
bottom = Exception("maximum context length")
middle = RuntimeError("middle")
middle.__cause__ = bottom
top = ValueError("top")
top.__cause__ = middle
assert _is_prompt_too_long(top) is True
def test_chain_no_match(self) -> None:
inner = Exception("rate limit exceeded")
outer = RuntimeError("wrapped")
outer.__cause__ = inner
assert _is_prompt_too_long(outer) is False
def test_cycle_detection(self) -> None:
"""Exception chain with a cycle should not infinite-loop."""
a = Exception("error a")
b = Exception("error b")
a.__cause__ = b
b.__cause__ = a # cycle
assert _is_prompt_too_long(a) is False
def test_all_patterns(self) -> None:
patterns = [
"prompt is too long",
"request too large",
"maximum context length",
"context_length_exceeded",
"input tokens exceed",
"input is too long",
"content length exceeds",
]
for pattern in patterns:
assert _is_prompt_too_long(Exception(pattern)) is True, pattern
# ---------------------------------------------------------------------------
# _reduce_context
# ---------------------------------------------------------------------------
class TestReduceContext:
@pytest.mark.asyncio
async def test_first_retry_compaction_success(self) -> None:
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
compacted = _build_transcript([("user", "hi"), ("assistant", "[summary]")])
with (
patch(
"backend.copilot.sdk.service.compact_transcript",
new_callable=AsyncMock,
return_value=compacted,
),
patch(
"backend.copilot.sdk.service.validate_transcript",
return_value=True,
),
patch(
"backend.copilot.sdk.service.write_transcript_to_tempfile",
return_value="/tmp/resume.jsonl",
),
):
ctx = await _reduce_context(
transcript, False, "sess-123", "/tmp/cwd", "[test]"
)
assert isinstance(ctx, ReducedContext)
assert ctx.use_resume is True
assert ctx.resume_file == "/tmp/resume.jsonl"
assert ctx.transcript_lost is False
assert ctx.tried_compaction is True
@pytest.mark.asyncio
async def test_compaction_fails_drops_transcript(self) -> None:
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
with patch(
"backend.copilot.sdk.service.compact_transcript",
new_callable=AsyncMock,
return_value=None,
):
ctx = await _reduce_context(
transcript, False, "sess-123", "/tmp/cwd", "[test]"
)
assert ctx.use_resume is False
assert ctx.resume_file is None
assert ctx.transcript_lost is True
assert ctx.tried_compaction is True
@pytest.mark.asyncio
async def test_already_tried_compaction_skips(self) -> None:
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
ctx = await _reduce_context(transcript, True, "sess-123", "/tmp/cwd", "[test]")
assert ctx.use_resume is False
assert ctx.transcript_lost is True
assert ctx.tried_compaction is True
@pytest.mark.asyncio
async def test_empty_transcript_drops(self) -> None:
ctx = await _reduce_context("", False, "sess-123", "/tmp/cwd", "[test]")
assert ctx.use_resume is False
assert ctx.transcript_lost is True
@pytest.mark.asyncio
async def test_compaction_returns_same_content_drops(self) -> None:
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
with patch(
"backend.copilot.sdk.service.compact_transcript",
new_callable=AsyncMock,
return_value=transcript, # same content
):
ctx = await _reduce_context(
transcript, False, "sess-123", "/tmp/cwd", "[test]"
)
assert ctx.transcript_lost is True
@pytest.mark.asyncio
async def test_write_tempfile_fails_drops(self) -> None:
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
compacted = _build_transcript([("user", "hi"), ("assistant", "[summary]")])
with (
patch(
"backend.copilot.sdk.service.compact_transcript",
new_callable=AsyncMock,
return_value=compacted,
),
patch(
"backend.copilot.sdk.service.validate_transcript",
return_value=True,
),
patch(
"backend.copilot.sdk.service.write_transcript_to_tempfile",
return_value=None,
),
):
ctx = await _reduce_context(
transcript, False, "sess-123", "/tmp/cwd", "[test]"
)
assert ctx.transcript_lost is True
# ---------------------------------------------------------------------------
# _iter_sdk_messages
# ---------------------------------------------------------------------------
class TestIterSdkMessages:
@pytest.mark.asyncio
async def test_yields_messages(self) -> None:
messages = ["msg1", "msg2", "msg3"]
client = AsyncMock()
async def _fake_receive() -> AsyncGenerator[str]:
for m in messages:
yield m
client.receive_response = _fake_receive
result = [msg async for msg in _iter_sdk_messages(client)]
assert result == messages
@pytest.mark.asyncio
async def test_heartbeat_on_timeout(self) -> None:
"""Yields None when asyncio.wait times out."""
client = AsyncMock()
received: list = []
async def _slow_receive() -> AsyncGenerator[str]:
await asyncio.sleep(100) # never completes
yield "never" # pragma: no cover — unreachable, yield makes this an async generator
client.receive_response = _slow_receive
with patch("backend.copilot.sdk.service._HEARTBEAT_INTERVAL", 0.01):
count = 0
async for msg in _iter_sdk_messages(client):
received.append(msg)
count += 1
if count >= 3:
break
assert all(m is None for m in received)
@pytest.mark.asyncio
async def test_exception_propagates(self) -> None:
client = AsyncMock()
async def _error_receive() -> AsyncGenerator[str]:
raise RuntimeError("SDK crash")
yield # pragma: no cover — unreachable, yield makes this an async generator
client.receive_response = _error_receive
with pytest.raises(RuntimeError, match="SDK crash"):
async for _ in _iter_sdk_messages(client):
pass
@pytest.mark.asyncio
async def test_task_cleanup_on_break(self) -> None:
"""Pending task is cancelled when generator is closed."""
client = AsyncMock()
async def _slow_receive() -> AsyncGenerator[str]:
yield "first"
await asyncio.sleep(100)
yield "second"
client.receive_response = _slow_receive
gen = _iter_sdk_messages(client)
first = await gen.__anext__()
assert first == "first"
await gen.aclose() # should cancel pending task cleanly

View File

@@ -288,90 +288,3 @@ class TestPromptSupplement:
# Count how many times this tool appears as a bullet point
count = docs.count(f"- **`{tool_name}`**")
assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)"
# ---------------------------------------------------------------------------
# _cleanup_sdk_tool_results — orchestration + rate-limiting
# ---------------------------------------------------------------------------
class TestCleanupSdkToolResults:
"""Tests for _cleanup_sdk_tool_results orchestration and sweep rate-limiting."""
# All valid cwds must start with /tmp/copilot- (the _SDK_CWD_PREFIX).
_CWD_PREFIX = "/tmp/copilot-"
@pytest.mark.asyncio
async def test_removes_cwd_directory(self):
"""Cleanup removes the session working directory."""
from .service import _cleanup_sdk_tool_results
cwd = "/tmp/copilot-test-cleanup-remove"
os.makedirs(cwd, exist_ok=True)
with patch("backend.copilot.sdk.service.cleanup_stale_project_dirs"):
import backend.copilot.sdk.service as svc_mod
svc_mod._last_sweep_time = 0.0
await _cleanup_sdk_tool_results(cwd)
assert not os.path.exists(cwd)
@pytest.mark.asyncio
async def test_sweep_runs_when_interval_elapsed(self):
"""cleanup_stale_project_dirs is called when 5-minute interval has elapsed."""
import backend.copilot.sdk.service as svc_mod
from .service import _cleanup_sdk_tool_results
cwd = "/tmp/copilot-test-sweep-elapsed"
os.makedirs(cwd, exist_ok=True)
with patch(
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
) as mock_sweep:
# Set last sweep to a time far in the past
svc_mod._last_sweep_time = 0.0
await _cleanup_sdk_tool_results(cwd)
mock_sweep.assert_called_once()
@pytest.mark.asyncio
async def test_sweep_skipped_within_interval(self):
"""cleanup_stale_project_dirs is NOT called when within 5-minute interval."""
import time
import backend.copilot.sdk.service as svc_mod
from .service import _cleanup_sdk_tool_results
cwd = "/tmp/copilot-test-sweep-ratelimit"
os.makedirs(cwd, exist_ok=True)
with patch(
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
) as mock_sweep:
# Set last sweep to now — interval not elapsed
svc_mod._last_sweep_time = time.time()
await _cleanup_sdk_tool_results(cwd)
mock_sweep.assert_not_called()
@pytest.mark.asyncio
async def test_rejects_path_outside_prefix(self, tmp_path):
"""Cleanup rejects a cwd that does not start with the expected prefix."""
from .service import _cleanup_sdk_tool_results
evil_cwd = str(tmp_path / "evil-path")
os.makedirs(evil_cwd, exist_ok=True)
with patch(
"backend.copilot.sdk.service.cleanup_stale_project_dirs"
) as mock_sweep:
await _cleanup_sdk_tool_results(evil_cwd)
# Directory should NOT have been removed (rejected early)
assert os.path.exists(evil_cwd)
mock_sweep.assert_not_called()

View File

@@ -146,7 +146,7 @@ def stash_pending_tool_output(tool_name: str, output: Any) -> None:
event.set()
async def wait_for_stash(timeout: float = 2.0) -> bool:
async def wait_for_stash(timeout: float = 0.5) -> bool:
"""Wait for a PostToolUse hook to stash tool output.
The SDK fires PostToolUse hooks asynchronously via ``start_soon()`` —
@@ -155,12 +155,12 @@ async def wait_for_stash(timeout: float = 2.0) -> bool:
by waiting on the ``_stash_event``, which is signaled by
:func:`stash_pending_tool_output`.
Returns ``True`` if a stash signal was received, ``False`` on timeout.
After the event fires, callers should ``await asyncio.sleep(0)`` to
give any remaining concurrent hooks a chance to complete.
The 2.0 s default was chosen based on production metrics: the original
0.5 s caused frequent timeouts under load (parallel tool calls, large
outputs). 2.0 s gives a comfortable margin while still failing fast
when the hook genuinely will not fire.
Returns ``True`` if a stash signal was received, ``False`` on timeout.
The timeout is a safety net — normally the stash happens within
microseconds of yielding to the event loop.
"""
event = _stash_event.get(None)
if event is None:
@@ -234,7 +234,9 @@ def create_tool_handler(base_tool: BaseTool):
try:
return await _execute_tool_sync(base_tool, user_id, session, args)
except Exception as e:
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
logger.error(
"Error executing tool %s: %s", base_tool.name, e, exc_info=True
)
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
return tool_handler
@@ -285,7 +287,7 @@ async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
resolved = os.path.realpath(os.path.expanduser(file_path))
try:
with open(resolved, encoding="utf-8", errors="replace") as f:
with open(resolved) as f:
selected = list(itertools.islice(f, offset, offset + limit))
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
# don't delete here — the SDK may read in multiple chunks.

View File

@@ -10,6 +10,9 @@ Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
filesystem for self-hosted) — no DB column needed.
"""
from __future__ import annotations
import asyncio
import logging
import os
import re
@@ -17,8 +20,12 @@ import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from uuid import uuid4
from backend.util import json
from backend.util.clients import get_openai_client
from backend.util.prompt import CompressResult, compress_context
from backend.util.workspace_storage import GCSWorkspaceStorage, get_workspace_storage
logger = logging.getLogger(__name__)
@@ -99,7 +106,14 @@ def strip_progress_entries(content: str) -> str:
continue
parent = entry.get("parentUuid", "")
original_parent = parent
while parent in stripped_uuids:
# seen_parents is local per-entry (not shared across iterations) so
# it can only detect cycles within a single ancestry walk, not across
# entries. This is intentional: each entry's parent chain is
# independent, and reusing a global set would incorrectly short-circuit
# valid re-use of the same UUID as a parent in different subtrees.
seen_parents: set[str] = set()
while parent in stripped_uuids and parent not in seen_parents:
seen_parents.add(parent)
parent = uuid_to_parent.get(parent, "")
if parent != original_parent:
entry["parentUuid"] = parent
@@ -151,110 +165,44 @@ def _projects_base() -> str:
return os.path.realpath(os.path.join(config_dir, "projects"))
_STALE_PROJECT_DIR_SECONDS = 12 * 3600 # 12 hours — matches max session lifetime
_MAX_PROJECT_DIRS_TO_SWEEP = 50 # limit per sweep to avoid long pauses
def _cli_project_dir(sdk_cwd: str) -> str | None:
"""Return the CLI's project directory for a given working directory.
def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
"""Remove CLI project directories older than ``_STALE_PROJECT_DIR_SECONDS``.
Each CoPilot SDK turn creates a unique ``~/.claude/projects/<encoded-cwd>/``
directory. These are intentionally kept across turns so the model can read
tool-result files via ``--resume``. However, after a session ends they
become stale. This function sweeps old ones to prevent unbounded disk
growth.
When *encoded_cwd* is provided the sweep is scoped to that single
directory, making the operation safe in multi-tenant environments where
multiple copilot sessions share the same host. Without it the function
falls back to sweeping all directories matching the copilot naming pattern
(``-tmp-copilot-``), which is only safe for single-tenant deployments.
Returns the number of directories removed.
Returns ``None`` if the path would escape the projects base.
"""
cwd_encoded = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
projects_base = _projects_base()
if not os.path.isdir(projects_base):
return 0
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
now = time.time()
removed = 0
# Scoped mode: only clean up the one directory for the current session.
if encoded_cwd:
target = Path(projects_base) / encoded_cwd
if not target.is_dir():
return 0
# Guard: only sweep copilot-generated dirs.
if "-tmp-copilot-" not in target.name:
logger.warning(
"[Transcript] Refusing to sweep non-copilot dir: %s", target.name
)
return 0
try:
# st_mtime is used as a proxy for session activity. Claude CLI writes
# its JSONL transcript into this directory during each turn, so mtime
# advances on every turn. A directory whose mtime is older than
# _STALE_PROJECT_DIR_SECONDS has not had an active turn in that window
# and is safe to remove (the session cannot --resume after cleanup).
age = now - target.stat().st_mtime
except OSError:
return 0
if age < _STALE_PROJECT_DIR_SECONDS:
return 0
try:
shutil.rmtree(target, ignore_errors=True)
removed = 1
except OSError:
pass
if removed:
logger.info(
"[Transcript] Swept stale CLI project dir %s (age %ds > %ds)",
target.name,
int(age),
_STALE_PROJECT_DIR_SECONDS,
)
return removed
# Unscoped fallback: sweep all copilot dirs across the projects base.
# Only safe for single-tenant deployments; callers should prefer the
# scoped variant by passing encoded_cwd.
try:
entries = Path(projects_base).iterdir()
except OSError as e:
logger.warning("[Transcript] Failed to list projects dir: %s", e)
return 0
for entry in entries:
if removed >= _MAX_PROJECT_DIRS_TO_SWEEP:
break
# Only sweep copilot-generated dirs (pattern: -tmp-copilot- or
# -private-tmp-copilot-).
if "-tmp-copilot-" not in entry.name:
continue
if not entry.is_dir():
continue
try:
# See the scoped-mode comment above: st_mtime advances on every turn,
# so a stale mtime reliably indicates an inactive session.
age = now - entry.stat().st_mtime
except OSError:
continue
if age < _STALE_PROJECT_DIR_SECONDS:
continue
try:
shutil.rmtree(entry, ignore_errors=True)
removed += 1
except OSError:
pass
if removed:
logger.info(
"[Transcript] Swept %d stale CLI project dirs (older than %ds)",
removed,
_STALE_PROJECT_DIR_SECONDS,
if not project_dir.startswith(projects_base + os.sep):
logger.warning(
"[Transcript] Project dir escaped projects base: %s", project_dir
)
return removed
return None
return project_dir
def _safe_glob_jsonl(project_dir: str) -> list[Path]:
"""Glob ``*.jsonl`` files, filtering out symlinks that escape the directory."""
try:
resolved_base = Path(project_dir).resolve()
except OSError as e:
logger.warning("[Transcript] Failed to resolve project dir: %s", e)
return []
result: list[Path] = []
for candidate in Path(project_dir).glob("*.jsonl"):
try:
resolved = candidate.resolve()
if resolved.is_relative_to(resolved_base):
result.append(resolved)
except (OSError, RuntimeError) as e:
logger.debug(
"[Transcript] Skipping invalid CLI session candidate %s: %s",
candidate,
e,
)
return result
def read_compacted_entries(transcript_path: str) -> list[dict] | None:
@@ -321,6 +269,63 @@ def read_compacted_entries(transcript_path: str) -> list[dict] | None:
return entries
def read_cli_session_file(sdk_cwd: str) -> str | None:
"""Read the CLI's own session file, which reflects any compaction.
The CLI writes its session transcript to
``~/.claude/projects/<encoded_cwd>/<session_id>.jsonl``.
Since each SDK turn uses a unique ``sdk_cwd``, there should be
exactly one ``.jsonl`` file in that directory.
Returns the file content, or ``None`` if not found.
"""
project_dir = _cli_project_dir(sdk_cwd)
if not project_dir or not os.path.isdir(project_dir):
return None
jsonl_files = _safe_glob_jsonl(project_dir)
if not jsonl_files:
logger.debug("[Transcript] No CLI session file found in %s", project_dir)
return None
# Pick the most recently modified file (should be only one per turn).
try:
session_file = max(jsonl_files, key=lambda p: p.stat().st_mtime)
except OSError as e:
logger.warning("[Transcript] Failed to inspect CLI session files: %s", e)
return None
try:
content = session_file.read_text()
logger.info(
"[Transcript] Read CLI session file: %s (%d bytes)",
session_file,
len(content),
)
return content
except OSError as e:
logger.warning("[Transcript] Failed to read CLI session file: %s", e)
return None
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
"""Remove the CLI's project directory for a specific working directory.
The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
safe to remove entirely after the transcript has been uploaded.
"""
project_dir = _cli_project_dir(sdk_cwd)
if not project_dir:
return
if os.path.isdir(project_dir):
shutil.rmtree(project_dir, ignore_errors=True)
logger.debug("[Transcript] Cleaned up CLI project dir: %s", project_dir)
else:
logger.debug("[Transcript] Project dir not found: %s", project_dir)
def write_transcript_to_tempfile(
transcript_content: str,
session_id: str,
@@ -336,7 +341,7 @@ def write_transcript_to_tempfile(
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
real_cwd = os.path.realpath(cwd)
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
logger.warning("[Transcript] cwd outside sandbox: %s", cwd)
return None
try:
@@ -346,17 +351,17 @@ def write_transcript_to_tempfile(
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
)
if not jsonl_path.startswith(real_cwd):
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
logger.warning("[Transcript] Path escaped cwd: %s", jsonl_path)
return None
with open(jsonl_path, "w") as f:
f.write(transcript_content)
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
logger.info("[Transcript] Wrote resume file: %s", jsonl_path)
return jsonl_path
except OSError as e:
logger.warning(f"[Transcript] Failed to write resume file: {e}")
logger.warning("[Transcript] Failed to write resume file: %s", e)
return None
@@ -417,8 +422,6 @@ def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, s
def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
"""Build a full storage path from (workspace_id, file_id, filename) parts."""
from backend.util.workspace_storage import GCSWorkspaceStorage
wid, fid, fname = parts
if isinstance(backend, GCSWorkspaceStorage):
blob = f"workspaces/{wid}/{fid}/{fname}"
@@ -457,17 +460,15 @@ async def upload_transcript(
content: Complete JSONL transcript (from TranscriptBuilder).
message_count: ``len(session.messages)`` at upload time.
"""
from backend.util.workspace_storage import get_workspace_storage
# Strip metadata entries (progress, file-history-snapshot, etc.)
# Note: SDK-built transcripts shouldn't have these, but strip for safety
stripped = strip_progress_entries(content)
if not validate_transcript(stripped):
# Log entry types for debugging — helps identify why validation failed
entry_types: list[str] = []
for line in stripped.strip().split("\n"):
entry = json.loads(line, fallback={"type": "INVALID_JSON"})
entry_types.append(entry.get("type", "?"))
entry_types = [
json.loads(line, fallback={"type": "INVALID_JSON"}).get("type", "?")
for line in stripped.strip().split("\n")
]
logger.warning(
"%s Skipping upload — stripped content not valid "
"(types=%s, stripped_len=%d, raw_len=%d)",
@@ -503,11 +504,14 @@ async def upload_transcript(
content=json.dumps(meta).encode("utf-8"),
)
except Exception as e:
logger.warning(f"{log_prefix} Failed to write metadata: {e}")
logger.warning("%s Failed to write metadata: %s", log_prefix, e)
logger.info(
f"{log_prefix} Uploaded {len(encoded)}B "
f"(stripped from {len(content)}B, msg_count={message_count})"
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
log_prefix,
len(encoded),
len(content),
message_count,
)
@@ -521,8 +525,6 @@ async def download_transcript(
Returns a ``TranscriptDownload`` with the JSONL content and the
``message_count`` watermark from the upload, or ``None`` if not found.
"""
from backend.util.workspace_storage import get_workspace_storage
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
@@ -530,10 +532,10 @@ async def download_transcript(
data = await storage.retrieve(path)
content = data.decode("utf-8")
except FileNotFoundError:
logger.debug(f"{log_prefix} No transcript in storage")
logger.debug("%s No transcript in storage", log_prefix)
return None
except Exception as e:
logger.warning(f"{log_prefix} Failed to download transcript: {e}")
logger.warning("%s Failed to download transcript: %s", log_prefix, e)
return None
# Try to load metadata (best-effort — old transcripts won't have it)
@@ -545,10 +547,14 @@ async def download_transcript(
meta = json.loads(meta_data.decode("utf-8"), fallback={})
message_count = meta.get("message_count", 0)
uploaded_at = meta.get("uploaded_at", 0.0)
except (FileNotFoundError, Exception):
except FileNotFoundError:
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
except Exception as e:
logger.debug("%s Failed to load transcript metadata: %s", log_prefix, e)
logger.info(f"{log_prefix} Downloaded {len(content)}B (msg_count={message_count})")
logger.info(
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
)
return TranscriptDownload(
content=content,
message_count=message_count,
@@ -562,8 +568,6 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
Removes both the ``.jsonl`` transcript and the companion ``.meta.json``
so stale ``message_count`` watermarks cannot corrupt gap-fill logic.
"""
from backend.util.workspace_storage import get_workspace_storage
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
@@ -580,3 +584,280 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
logger.info("[Transcript] Deleted metadata for session %s", session_id)
except Exception as e:
logger.warning("[Transcript] Failed to delete metadata: %s", e)
# ---------------------------------------------------------------------------
# Transcript compaction — LLM summarization for prompt-too-long recovery
# ---------------------------------------------------------------------------
# JSONL protocol values used in transcript serialization.
STOP_REASON_END_TURN = "end_turn"
COMPACT_MSG_ID_PREFIX = "msg_compact_"
ENTRY_TYPE_MESSAGE = "message"
def _flatten_assistant_content(blocks: list) -> str:
"""Flatten assistant content blocks into a single plain-text string.
Structured ``tool_use`` blocks are converted to ``[tool_use: name]``
placeholders. This is intentional: ``compress_context`` requires plain
text for token counting and LLM summarization. The structural loss is
acceptable because compaction only runs when the original transcript was
already too large for the model — a summarized plain-text version is
better than no context at all.
"""
parts: list[str] = []
for block in blocks:
if isinstance(block, dict):
btype = block.get("type", "")
if btype == "text":
parts.append(block.get("text", ""))
elif btype == "tool_use":
parts.append(f"[tool_use: {block.get('name', '?')}]")
else:
# Preserve non-text blocks (e.g. image) as placeholders.
# Use __prefix__ to distinguish from literal user text.
parts.append(f"[__{btype}__]")
elif isinstance(block, str):
parts.append(block)
return "\n".join(parts) if parts else ""
def _flatten_tool_result_content(blocks: list) -> str:
"""Flatten tool_result and other content blocks into plain text.
Handles nested tool_result structures, text blocks, and raw strings.
Uses ``json.dumps`` as fallback for dict blocks without a ``text`` key
or where ``text`` is ``None``.
Like ``_flatten_assistant_content``, structured blocks (images, nested
tool results) are reduced to text representations for compression.
"""
str_parts: list[str] = []
for block in blocks:
if isinstance(block, dict) and block.get("type") == "tool_result":
inner = block.get("content") or ""
if isinstance(inner, list):
for sub in inner:
if isinstance(sub, dict):
sub_type = sub.get("type")
if sub_type in ("image", "document"):
# Avoid serializing base64 binary data into
# the compaction input — use a placeholder.
str_parts.append(f"[__{sub_type}__]")
elif sub_type == "text" or sub.get("text") is not None:
str_parts.append(str(sub.get("text", "")))
else:
str_parts.append(json.dumps(sub))
else:
str_parts.append(str(sub))
else:
str_parts.append(str(inner))
elif isinstance(block, dict) and block.get("type") == "text":
str_parts.append(str(block.get("text", "")))
elif isinstance(block, dict):
# Preserve non-text/non-tool_result blocks (e.g. image) as placeholders.
# Use __prefix__ to distinguish from literal user text.
btype = block.get("type", "unknown")
str_parts.append(f"[__{btype}__]")
elif isinstance(block, str):
str_parts.append(block)
return "\n".join(str_parts) if str_parts else ""
def _transcript_to_messages(content: str) -> list[dict]:
"""Convert JSONL transcript entries to plain message dicts for compression.
Parses each line of the JSONL *content*, skips strippable metadata entries
(progress, file-history-snapshot, etc.), and extracts the ``role`` and
flattened ``content`` from the ``message`` field of each remaining entry.
Structured content blocks (``tool_use``, ``tool_result``, images) are
flattened to plain text via ``_flatten_assistant_content`` and
``_flatten_tool_result_content`` so that ``compress_context`` can
perform token counting and LLM summarization on uniform strings.
Returns:
A list of ``{"role": str, "content": str}`` dicts suitable for
``compress_context``.
"""
messages: list[dict] = []
for line in content.strip().split("\n"):
if not line.strip():
continue
entry = json.loads(line, fallback=None)
if not isinstance(entry, dict):
continue
if entry.get("type", "") in STRIPPABLE_TYPES and not entry.get(
"isCompactSummary"
):
continue
msg = entry.get("message", {})
role = msg.get("role", "")
if not role:
continue
msg_dict: dict = {"role": role}
raw_content = msg.get("content")
if role == "assistant" and isinstance(raw_content, list):
msg_dict["content"] = _flatten_assistant_content(raw_content)
elif isinstance(raw_content, list):
msg_dict["content"] = _flatten_tool_result_content(raw_content)
else:
msg_dict["content"] = raw_content or ""
messages.append(msg_dict)
return messages
def _messages_to_transcript(messages: list[dict]) -> str:
"""Convert compressed message dicts back to JSONL transcript format.
Rebuilds a minimal JSONL transcript from the ``{"role", "content"}``
dicts returned by ``compress_context``. Each message becomes one JSONL
line with a fresh ``uuid`` / ``parentUuid`` chain so the CLI's
``--resume`` flag can reconstruct a valid conversation tree.
Assistant messages are wrapped in the full ``message`` envelope
(``id``, ``model``, ``stop_reason``, structured ``content`` blocks)
that the CLI expects. User messages use the simpler ``{role, content}``
form.
Returns:
A newline-terminated JSONL string, or an empty string if *messages*
is empty.
"""
lines: list[str] = []
last_uuid: str = "" # root entry uses empty string, not null
for msg in messages:
role = msg.get("role", "user")
entry_type = "assistant" if role == "assistant" else "user"
uid = str(uuid4())
content = msg.get("content", "")
if role == "assistant":
message: dict = {
"role": "assistant",
"model": "",
"id": f"{COMPACT_MSG_ID_PREFIX}{uuid4().hex[:24]}",
"type": ENTRY_TYPE_MESSAGE,
"content": [{"type": "text", "text": content}] if content else [],
"stop_reason": STOP_REASON_END_TURN,
"stop_sequence": None,
}
else:
message = {"role": role, "content": content}
entry = {
"type": entry_type,
"uuid": uid,
"parentUuid": last_uuid,
"message": message,
}
lines.append(json.dumps(entry, separators=(",", ":")))
last_uuid = uid
return "\n".join(lines) + "\n" if lines else ""
_COMPACTION_TIMEOUT_SECONDS = 60
_TRUNCATION_TIMEOUT_SECONDS = 30
async def _run_compression(
messages: list[dict],
model: str,
log_prefix: str,
) -> CompressResult:
"""Run LLM-based compression with truncation fallback.
Uses the shared OpenAI client from ``get_openai_client()``.
If no client is configured or the LLM call fails, falls back to
truncation-based compression which drops older messages without
summarization.
A 60-second timeout prevents a hung LLM call from blocking the
retry path indefinitely. The truncation fallback also has a
30-second timeout to guard against slow tokenization on very large
transcripts.
"""
client = get_openai_client()
if client is None:
logger.warning("%s No OpenAI client configured, using truncation", log_prefix)
return await asyncio.wait_for(
compress_context(messages=messages, model=model, client=None),
timeout=_TRUNCATION_TIMEOUT_SECONDS,
)
try:
return await asyncio.wait_for(
compress_context(messages=messages, model=model, client=client),
timeout=_COMPACTION_TIMEOUT_SECONDS,
)
except Exception as e:
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
return await asyncio.wait_for(
compress_context(messages=messages, model=model, client=None),
timeout=_TRUNCATION_TIMEOUT_SECONDS,
)
async def compact_transcript(
content: str,
*,
model: str,
log_prefix: str = "[Transcript]",
) -> str | None:
"""Compact an oversized JSONL transcript using LLM summarization.
Converts transcript entries to plain messages, runs ``compress_context``
(the same compressor used for pre-query history), and rebuilds JSONL.
Structured content (``tool_use`` blocks, ``tool_result`` nesting, images)
is flattened to plain text for compression. This matches the fidelity of
the Plan C (DB compression) fallback path, where
``_format_conversation_context`` similarly renders tool calls as
``You called tool: name(args)`` and results as ``Tool result: ...``.
Neither path preserves structured API content blocks — the compacted
context serves as text history for the LLM, which creates proper
structured tool calls going forward.
Images are per-turn attachments loaded from workspace storage by file ID
(via ``_prepare_file_attachments``), not part of the conversation history.
They are re-attached each turn and are unaffected by compaction.
Returns the compacted JSONL string, or ``None`` on failure.
See also:
``_compress_messages`` in ``service.py`` — compresses ``ChatMessage``
lists for pre-query DB history. Both share ``compress_context()``
but operate on different input formats (JSONL transcript entries
here vs. ChatMessage dicts there).
"""
messages = _transcript_to_messages(content)
if len(messages) < 2:
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))
return None
try:
result = await _run_compression(messages, model, log_prefix)
if not result.was_compacted:
# Compressor says it's within budget, but the SDK rejected it.
# Return None so the caller falls through to DB fallback.
logger.warning(
"%s Compressor reports within budget but SDK rejected — "
"signalling failure",
log_prefix,
)
return None
logger.info(
"%s Compacted transcript: %d->%d tokens (%d summarized, %d dropped)",
log_prefix,
result.original_token_count,
result.token_count,
result.messages_summarized,
result.messages_dropped,
)
compacted = _messages_to_transcript(result.messages)
if not validate_transcript(compacted):
logger.warning("%s Compacted transcript failed validation", log_prefix)
return None
return compacted
except Exception as e:
logger.error(
"%s Transcript compaction failed: %s", log_prefix, e, exc_info=True
)
return None

View File

@@ -68,7 +68,7 @@ class TranscriptBuilder:
type=entry_type,
uuid=data.get("uuid") or str(uuid4()),
parentUuid=data.get("parentUuid"),
isCompactSummary=data.get("isCompactSummary") or None,
isCompactSummary=data.get("isCompactSummary"),
message=data.get("message", {}),
)

View File

@@ -1,7 +1,7 @@
"""Unit tests for JSONL transcript management utilities."""
import os
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -9,7 +9,9 @@ from backend.util import json
from .transcript import (
STRIPPABLE_TYPES,
_cli_project_dir,
delete_transcript,
read_cli_session_file,
read_compacted_entries,
strip_progress_entries,
validate_transcript,
@@ -290,6 +292,85 @@ class TestStripProgressEntries:
assert asst_entry["parentUuid"] == "u1" # reparented
# --- read_cli_session_file ---
class TestReadCliSessionFile:
def test_no_matching_files_returns_none(self, tmp_path, monkeypatch):
"""read_cli_session_file returns None when no .jsonl files exist."""
# Create a project dir with no jsonl files
project_dir = tmp_path / "projects" / "encoded-cwd"
project_dir.mkdir(parents=True)
monkeypatch.setattr(
"backend.copilot.sdk.transcript._cli_project_dir",
lambda sdk_cwd: str(project_dir),
)
assert read_cli_session_file("/fake/cwd") is None
def test_one_jsonl_file_returns_content(self, tmp_path, monkeypatch):
"""read_cli_session_file returns the content of a single .jsonl file."""
project_dir = tmp_path / "projects" / "encoded-cwd"
project_dir.mkdir(parents=True)
jsonl_file = project_dir / "session.jsonl"
jsonl_file.write_text("line1\nline2\n")
monkeypatch.setattr(
"backend.copilot.sdk.transcript._cli_project_dir",
lambda sdk_cwd: str(project_dir),
)
result = read_cli_session_file("/fake/cwd")
assert result == "line1\nline2\n"
def test_symlink_escaping_project_dir_is_skipped(self, tmp_path, monkeypatch):
"""read_cli_session_file skips symlinks that escape the project dir."""
project_dir = tmp_path / "projects" / "encoded-cwd"
project_dir.mkdir(parents=True)
# Create a file outside the project dir
outside = tmp_path / "outside"
outside.mkdir()
outside_file = outside / "evil.jsonl"
outside_file.write_text("should not be read\n")
# Symlink from inside project_dir to outside file
symlink = project_dir / "evil.jsonl"
symlink.symlink_to(outside_file)
monkeypatch.setattr(
"backend.copilot.sdk.transcript._cli_project_dir",
lambda sdk_cwd: str(project_dir),
)
# The symlink target resolves outside project_dir, so it should be skipped
result = read_cli_session_file("/fake/cwd")
assert result is None
# --- _cli_project_dir ---
class TestCliProjectDir:
def test_returns_none_for_path_traversal(self, tmp_path, monkeypatch):
"""_cli_project_dir returns None when the project dir symlink escapes projects base."""
config_dir = tmp_path / "config"
config_dir.mkdir()
projects_dir = config_dir / "projects"
projects_dir.mkdir()
monkeypatch.setenv("CLAUDE_CONFIG_DIR", str(config_dir))
# Create a symlink inside projects/ that points outside of it.
# _cli_project_dir encodes the cwd as all-alnum-hyphens, so use a
# cwd whose encoded form matches the symlink name we create.
evil_target = tmp_path / "escaped"
evil_target.mkdir()
# The encoded form of "/evil/cwd" is "-evil-cwd"
symlink_path = projects_dir / "-evil-cwd"
symlink_path.symlink_to(evil_target)
result = _cli_project_dir("/evil/cwd")
assert result is None
# --- delete_transcript ---
@@ -301,7 +382,7 @@ class TestDeleteTranscript:
mock_storage.delete = AsyncMock()
with patch(
"backend.util.workspace_storage.get_workspace_storage",
"backend.copilot.sdk.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
@@ -321,7 +402,7 @@ class TestDeleteTranscript:
)
with patch(
"backend.util.workspace_storage.get_workspace_storage",
"backend.copilot.sdk.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
@@ -339,7 +420,7 @@ class TestDeleteTranscript:
)
with patch(
"backend.util.workspace_storage.get_workspace_storage",
"backend.copilot.sdk.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
@@ -819,206 +900,131 @@ class TestCompactionFlowIntegration:
# ---------------------------------------------------------------------------
# cleanup_stale_project_dirs
# _run_compression (direct tests for the 3 code paths)
# ---------------------------------------------------------------------------
class TestCleanupStaleProjectDirs:
"""Tests for cleanup_stale_project_dirs (disk leak prevention)."""
class TestRunCompression:
"""Direct tests for ``_run_compression`` covering all 3 code paths.
def test_removes_old_copilot_dirs(self, tmp_path, monkeypatch):
"""Directories matching copilot pattern older than threshold are removed."""
from backend.copilot.sdk.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
Paths:
(a) No OpenAI client configured → truncation fallback immediately.
(b) LLM success → returns LLM-compressed result.
(c) LLM call raises → truncation fallback.
"""
def _make_compress_result(self, was_compacted: bool, msgs=None):
"""Build a minimal CompressResult-like object."""
from types import SimpleNamespace
return SimpleNamespace(
was_compacted=was_compacted,
messages=msgs or [{"role": "user", "content": "summary"}],
original_token_count=500,
token_count=100 if was_compacted else 500,
messages_summarized=2 if was_compacted else 0,
messages_dropped=0,
)
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: str(projects_dir),
@pytest.mark.asyncio
async def test_no_client_uses_truncation(self):
"""Path (a): ``get_openai_client()`` returns None → truncation only."""
from .transcript import _run_compression
truncation_result = self._make_compress_result(
True, [{"role": "user", "content": "truncated"}]
)
# Create a stale dir
stale = projects_dir / "-tmp-copilot-old-session"
stale.mkdir()
# Set mtime to past the threshold
import time
with (
patch(
"backend.copilot.sdk.transcript.get_openai_client",
return_value=None,
),
patch(
"backend.copilot.sdk.transcript.compress_context",
new_callable=AsyncMock,
return_value=truncation_result,
) as mock_compress,
):
result = await _run_compression(
[{"role": "user", "content": "hello"}],
model="test-model",
log_prefix="[test]",
)
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
os.utime(stale, (old_time, old_time))
# Create a fresh dir
fresh = projects_dir / "-tmp-copilot-new-session"
fresh.mkdir()
removed = cleanup_stale_project_dirs()
assert removed == 1
assert not stale.exists()
assert fresh.exists()
def test_ignores_non_copilot_dirs(self, tmp_path, monkeypatch):
"""Directories not matching copilot pattern are left alone."""
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: str(projects_dir),
# compress_context called with client=None (truncation mode)
call_kwargs = mock_compress.call_args
assert (
call_kwargs.kwargs.get("client") is None
or (call_kwargs.args and call_kwargs.args[2] is None)
or mock_compress.call_args[1].get("client") is None
)
assert result is truncation_result
# Non-copilot dir that's old
import time
@pytest.mark.asyncio
async def test_llm_success_returns_llm_result(self):
"""Path (b): ``get_openai_client()`` returns a client → LLM compresses."""
from .transcript import _run_compression
other = projects_dir / "some-other-project"
other.mkdir()
old_time = time.time() - 999999
os.utime(other, (old_time, old_time))
removed = cleanup_stale_project_dirs()
assert removed == 0
assert other.exists()
def test_ttl_boundary_not_removed(self, tmp_path, monkeypatch):
"""A directory exactly at the TTL boundary should NOT be removed."""
from backend.copilot.sdk.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
llm_result = self._make_compress_result(
True, [{"role": "user", "content": "LLM summary"}]
)
mock_client = MagicMock()
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: str(projects_dir),
with (
patch(
"backend.copilot.sdk.transcript.get_openai_client",
return_value=mock_client,
),
patch(
"backend.copilot.sdk.transcript.compress_context",
new_callable=AsyncMock,
return_value=llm_result,
) as mock_compress,
):
result = await _run_compression(
[{"role": "user", "content": "long conversation"}],
model="test-model",
log_prefix="[test]",
)
# compress_context called with the real client
assert mock_compress.called
assert result is llm_result
@pytest.mark.asyncio
async def test_llm_failure_falls_back_to_truncation(self):
"""Path (c): LLM call raises → truncation fallback used instead."""
from .transcript import _run_compression
truncation_result = self._make_compress_result(
True, [{"role": "user", "content": "truncated fallback"}]
)
mock_client = MagicMock()
call_count = [0]
import time
async def _compress_side_effect(**kwargs):
call_count[0] += 1
if kwargs.get("client") is not None:
raise RuntimeError("LLM timeout")
return truncation_result
# Dir that's exactly at the TTL (age == threshold, not >) — should survive
boundary = projects_dir / "-tmp-copilot-boundary"
boundary.mkdir()
boundary_time = time.time() - _STALE_PROJECT_DIR_SECONDS + 1
os.utime(boundary, (boundary_time, boundary_time))
with (
patch(
"backend.copilot.sdk.transcript.get_openai_client",
return_value=mock_client,
),
patch(
"backend.copilot.sdk.transcript.compress_context",
side_effect=_compress_side_effect,
),
):
result = await _run_compression(
[{"role": "user", "content": "long conversation"}],
model="test-model",
log_prefix="[test]",
)
removed = cleanup_stale_project_dirs()
assert removed == 0
assert boundary.exists()
def test_skips_non_directory_entries(self, tmp_path, monkeypatch):
"""Regular files matching the copilot pattern are not removed."""
from backend.copilot.sdk.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: str(projects_dir),
)
import time
# Create a regular FILE (not a dir) with the copilot pattern name
stale_file = projects_dir / "-tmp-copilot-stale-file"
stale_file.write_text("not a dir")
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
os.utime(stale_file, (old_time, old_time))
removed = cleanup_stale_project_dirs()
assert removed == 0
assert stale_file.exists()
def test_missing_base_dir_returns_zero(self, tmp_path, monkeypatch):
"""If the projects base directory doesn't exist, return 0 gracefully."""
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
nonexistent = str(tmp_path / "does-not-exist" / "projects")
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: nonexistent,
)
removed = cleanup_stale_project_dirs()
assert removed == 0
def test_scoped_removes_only_target_dir(self, tmp_path, monkeypatch):
"""When encoded_cwd is supplied only that directory is swept."""
import time
from backend.copilot.sdk.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: str(projects_dir),
)
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
# Two stale copilot dirs
target = projects_dir / "-tmp-copilot-session-abc"
target.mkdir()
os.utime(target, (old_time, old_time))
other = projects_dir / "-tmp-copilot-session-xyz"
other.mkdir()
os.utime(other, (old_time, old_time))
# Only the target dir should be removed
removed = cleanup_stale_project_dirs(encoded_cwd="-tmp-copilot-session-abc")
assert removed == 1
assert not target.exists()
assert other.exists() # untouched — not the current session
def test_scoped_fresh_dir_not_removed(self, tmp_path, monkeypatch):
"""Scoped sweep leaves a fresh directory alone."""
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: str(projects_dir),
)
fresh = projects_dir / "-tmp-copilot-session-new"
fresh.mkdir()
# mtime is now — well within TTL
removed = cleanup_stale_project_dirs(encoded_cwd="-tmp-copilot-session-new")
assert removed == 0
assert fresh.exists()
def test_scoped_non_copilot_dir_not_removed(self, tmp_path, monkeypatch):
"""Scoped sweep refuses to remove a non-copilot directory."""
import time
from backend.copilot.sdk.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
lambda: str(projects_dir),
)
old_time = time.time() - _STALE_PROJECT_DIR_SECONDS - 100
non_copilot = projects_dir / "some-other-project"
non_copilot.mkdir()
os.utime(non_copilot, (old_time, old_time))
removed = cleanup_stale_project_dirs(encoded_cwd="some-other-project")
assert removed == 0
assert non_copilot.exists()
# compress_context called twice: once for LLM (raises), once for truncation
assert call_count[0] == 2
assert result is truncation_result

View File

@@ -41,8 +41,7 @@ import contextlib
import logging
from typing import Any, Awaitable, Callable, Literal
from e2b import AsyncSandbox
from e2b.sandbox.sandbox_api import SandboxLifecycle
from e2b import AsyncSandbox, SandboxLifecycle
from backend.data.redis_client import get_redis_async

View File

@@ -2,7 +2,6 @@
import base64
import logging
import mimetypes
import os
from typing import Any, Optional
@@ -11,9 +10,7 @@ from pydantic import BaseModel
from backend.copilot.context import (
E2B_WORKDIR,
get_current_sandbox,
get_sdk_cwd,
get_workspace_manager,
is_allowed_local_path,
resolve_sandbox_path,
)
from backend.copilot.model import ChatSession
@@ -27,10 +24,6 @@ from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
# Sentinel file_id used when a tool-result file is read directly from the local
# host filesystem (rather than from workspace storage).
_LOCAL_TOOL_RESULT_FILE_ID = "local"
async def _resolve_write_content(
content_text: str | None,
@@ -282,93 +275,6 @@ class WorkspaceFileContentResponse(ToolResponseBase):
content_base64: str
_MAX_LOCAL_TOOL_RESULT_BYTES = 10 * 1024 * 1024 # 10 MB
def _read_local_tool_result(
path: str,
char_offset: int,
char_length: Optional[int],
session_id: str,
sdk_cwd: str | None = None,
) -> ToolResponseBase:
"""Read an SDK tool-result file from local disk.
This is a fallback for when the model mistakenly calls
``read_workspace_file`` with an SDK tool-result path that only exists on
the host filesystem, not in cloud workspace storage.
Defence-in-depth: validates *path* via :func:`is_allowed_local_path`
regardless of what the caller has already checked.
"""
# TOCTOU: path validated then opened separately. Acceptable because
# the tool-results directory is server-controlled, not user-writable.
expanded = os.path.realpath(os.path.expanduser(path))
# Defence-in-depth: re-check with resolved path (caller checked raw path).
if not is_allowed_local_path(expanded, sdk_cwd or get_sdk_cwd()):
return ErrorResponse(
message=f"Path not allowed: {os.path.basename(path)}", session_id=session_id
)
try:
# The 10 MB cap (_MAX_LOCAL_TOOL_RESULT_BYTES) bounds memory usage.
# Pre-read size check prevents loading files far above the cap;
# the remaining TOCTOU gap is acceptable for server-controlled paths.
file_size = os.path.getsize(expanded)
if file_size > _MAX_LOCAL_TOOL_RESULT_BYTES:
return ErrorResponse(
message=(f"File too large: {os.path.basename(path)}"),
session_id=session_id,
)
# Detect binary files: try strict UTF-8 first, fall back to
# base64-encoding the raw bytes for binary content.
with open(expanded, "rb") as fh:
raw = fh.read()
try:
text_content = raw.decode("utf-8")
except UnicodeDecodeError:
# Binary file — return raw base64, ignore char_offset/char_length
return WorkspaceFileContentResponse(
file_id=_LOCAL_TOOL_RESULT_FILE_ID,
name=os.path.basename(path),
path=path,
mime_type=mimetypes.guess_type(path)[0] or "application/octet-stream",
content_base64=base64.b64encode(raw).decode("ascii"),
message=(
f"Read {file_size:,} bytes (binary) from local tool-result "
f"{os.path.basename(path)}"
),
session_id=session_id,
)
end = (
char_offset + char_length if char_length is not None else len(text_content)
)
slice_text = text_content[char_offset:end]
except FileNotFoundError:
return ErrorResponse(
message=f"File not found: {os.path.basename(path)}", session_id=session_id
)
except Exception as exc:
return ErrorResponse(
message=f"Error reading file: {type(exc).__name__}", session_id=session_id
)
return WorkspaceFileContentResponse(
file_id=_LOCAL_TOOL_RESULT_FILE_ID,
name=os.path.basename(path),
path=path,
mime_type=mimetypes.guess_type(path)[0] or "text/plain",
content_base64=base64.b64encode(slice_text.encode("utf-8")).decode("ascii"),
message=(
f"Read chars {char_offset}\u2013{char_offset + len(slice_text)} "
f"of {len(text_content):,} chars from local tool-result "
f"{os.path.basename(path)}"
),
session_id=session_id,
)
class WorkspaceFileMetadataResponse(ToolResponseBase):
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
@@ -627,14 +533,6 @@ class ReadWorkspaceFileTool(BaseTool):
manager = await get_workspace_manager(user_id, session_id)
resolved = await _resolve_file(manager, file_id, path, session_id)
if isinstance(resolved, ErrorResponse):
# Fallback: if the path is an SDK tool-result on local disk,
# read it directly instead of failing. The model sometimes
# calls read_workspace_file for these paths by mistake.
sdk_cwd = get_sdk_cwd()
if path and is_allowed_local_path(path, sdk_cwd):
return _read_local_tool_result(
path, char_offset, char_length, session_id, sdk_cwd=sdk_cwd
)
return resolved
target_file_id, file_info = resolved

View File

@@ -2,25 +2,18 @@
import base64
import os
import shutil
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.context import SDK_PROJECTS_DIR, _current_project_dir
from backend.copilot.tools._test_data import make_session, setup_test_data
from backend.copilot.tools.models import ErrorResponse
from backend.copilot.tools.workspace_files import (
_MAX_LOCAL_TOOL_RESULT_BYTES,
DeleteWorkspaceFileTool,
ListWorkspaceFilesTool,
ReadWorkspaceFileTool,
WorkspaceDeleteResponse,
WorkspaceFileContentResponse,
WorkspaceFileListResponse,
WorkspaceWriteResponse,
WriteWorkspaceFileTool,
_read_local_tool_result,
_resolve_write_content,
_validate_ephemeral_path,
)
@@ -332,294 +325,3 @@ async def test_write_workspace_file_source_path(setup_test_data):
await delete_tool._execute(
user_id=user.id, session=session, file_id=write_resp.file_id
)
# ---------------------------------------------------------------------------
# _read_local_tool_result — local disk fallback for SDK tool-result files
# ---------------------------------------------------------------------------
_CONV_UUID = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
class TestReadLocalToolResult:
"""Tests for _read_local_tool_result (local disk fallback)."""
def _make_tool_result(self, encoded: str, filename: str, content: bytes) -> str:
"""Create a tool-results file and return its path."""
tool_dir = os.path.join(SDK_PROJECTS_DIR, encoded, _CONV_UUID, "tool-results")
os.makedirs(tool_dir, exist_ok=True)
filepath = os.path.join(tool_dir, filename)
with open(filepath, "wb") as f:
f.write(content)
return filepath
def _cleanup(self, encoded: str) -> None:
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
def test_read_text_file(self):
"""Read a UTF-8 text tool-result file."""
encoded = "-tmp-copilot-local-read-text"
path = self._make_tool_result(encoded, "output.txt", b"hello world")
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 0, None, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
decoded = base64.b64decode(result.content_base64).decode("utf-8")
assert decoded == "hello world"
assert "text/plain" in result.mime_type
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_read_text_with_offset(self):
"""Read a slice of a text file using char_offset and char_length."""
encoded = "-tmp-copilot-local-read-offset"
path = self._make_tool_result(encoded, "data.txt", b"ABCDEFGHIJ")
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 3, 4, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
decoded = base64.b64decode(result.content_base64).decode("utf-8")
assert decoded == "DEFG"
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_read_binary_file(self):
"""Binary files are returned as raw base64."""
encoded = "-tmp-copilot-local-read-binary"
binary_data = bytes(range(256))
path = self._make_tool_result(encoded, "image.png", binary_data)
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 0, None, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
decoded = base64.b64decode(result.content_base64)
assert decoded == binary_data
assert "binary" in result.message
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_disallowed_path_rejected(self):
"""Paths not under allowed directories are rejected."""
result = _read_local_tool_result("/etc/passwd", 0, None, "s1")
assert isinstance(result, ErrorResponse)
assert "not allowed" in result.message.lower()
def test_file_not_found(self):
"""Missing files return an error."""
encoded = "-tmp-copilot-local-read-missing"
tool_dir = os.path.join(SDK_PROJECTS_DIR, encoded, _CONV_UUID, "tool-results")
os.makedirs(tool_dir, exist_ok=True)
path = os.path.join(tool_dir, "nope.txt")
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 0, None, "s1")
assert isinstance(result, ErrorResponse)
assert "not found" in result.message.lower()
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_file_too_large(self, monkeypatch):
"""Files exceeding the size limit are rejected."""
encoded = "-tmp-copilot-local-read-large"
# Create a small file but fake os.path.getsize to return a huge value
path = self._make_tool_result(encoded, "big.txt", b"small")
token = _current_project_dir.set(encoded)
monkeypatch.setattr(
"os.path.getsize", lambda _: _MAX_LOCAL_TOOL_RESULT_BYTES + 1
)
try:
result = _read_local_tool_result(path, 0, None, "s1")
assert isinstance(result, ErrorResponse)
assert "too large" in result.message.lower()
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_offset_beyond_file_length(self):
"""Offset past end-of-file returns empty content."""
encoded = "-tmp-copilot-local-read-past-eof"
path = self._make_tool_result(encoded, "short.txt", b"abc")
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 999, 10, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
decoded = base64.b64decode(result.content_base64).decode("utf-8")
assert decoded == ""
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_zero_length_read(self):
"""Requesting zero characters returns empty content."""
encoded = "-tmp-copilot-local-read-zero-len"
path = self._make_tool_result(encoded, "data.txt", b"ABCDEF")
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 2, 0, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
decoded = base64.b64decode(result.content_base64).decode("utf-8")
assert decoded == ""
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_mime_type_from_json_extension(self):
"""JSON files get application/json MIME type, not hardcoded text/plain."""
encoded = "-tmp-copilot-local-read-json"
path = self._make_tool_result(encoded, "result.json", b'{"key": "value"}')
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 0, None, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
assert result.mime_type == "application/json"
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_mime_type_from_png_extension(self):
"""Binary .png files get image/png MIME type via mimetypes."""
encoded = "-tmp-copilot-local-read-png-mime"
binary_data = bytes(range(256))
path = self._make_tool_result(encoded, "chart.png", binary_data)
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 0, None, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
assert result.mime_type == "image/png"
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_explicit_sdk_cwd_parameter(self):
"""The sdk_cwd parameter overrides get_sdk_cwd() for path validation."""
encoded = "-tmp-copilot-local-read-sdkcwd"
path = self._make_tool_result(encoded, "out.txt", b"content")
token = _current_project_dir.set(encoded)
try:
# Pass sdk_cwd explicitly — should still succeed because the path
# is under SDK_PROJECTS_DIR which is always allowed.
result = _read_local_tool_result(
path, 0, None, "s1", sdk_cwd="/tmp/copilot-test"
)
assert isinstance(result, WorkspaceFileContentResponse)
decoded = base64.b64decode(result.content_base64).decode("utf-8")
assert decoded == "content"
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
def test_offset_with_no_length_reads_to_end(self):
"""When char_length is None, read from offset to end of file."""
encoded = "-tmp-copilot-local-read-offset-noLen"
path = self._make_tool_result(encoded, "data.txt", b"0123456789")
token = _current_project_dir.set(encoded)
try:
result = _read_local_tool_result(path, 5, None, "s1")
assert isinstance(result, WorkspaceFileContentResponse)
decoded = base64.b64decode(result.content_base64).decode("utf-8")
assert decoded == "56789"
finally:
_current_project_dir.reset(token)
self._cleanup(encoded)
# ---------------------------------------------------------------------------
# ReadWorkspaceFileTool fallback to _read_local_tool_result
# ---------------------------------------------------------------------------
@pytest.mark.asyncio(loop_scope="session")
async def test_read_workspace_file_falls_back_to_local_tool_result(setup_test_data):
"""When _resolve_file returns ErrorResponse for an allowed local path,
ReadWorkspaceFileTool should fall back to _read_local_tool_result."""
user = setup_test_data["user"]
session = make_session(user.id)
# Create a real tool-result file on disk so the fallback can read it.
encoded = "-tmp-copilot-fallback-test"
conv_uuid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
tool_dir = os.path.join(SDK_PROJECTS_DIR, encoded, conv_uuid, "tool-results")
os.makedirs(tool_dir, exist_ok=True)
filepath = os.path.join(tool_dir, "result.txt")
with open(filepath, "w") as f:
f.write("fallback content")
token = _current_project_dir.set(encoded)
try:
# Mock _resolve_file to return an ErrorResponse (simulating "file not
# found in workspace") so the fallback branch is exercised.
mock_resolve = AsyncMock(
return_value=ErrorResponse(
message="File not found at path: result.txt",
session_id=session.session_id,
)
)
with patch("backend.copilot.tools.workspace_files._resolve_file", mock_resolve):
read_tool = ReadWorkspaceFileTool()
result = await read_tool._execute(
user_id=user.id,
session=session,
path=filepath,
)
# Should have fallen back to _read_local_tool_result and succeeded.
assert isinstance(result, WorkspaceFileContentResponse), (
f"Expected fallback to local read, got {type(result).__name__}: "
f"{getattr(result, 'message', '')}"
)
decoded = base64.b64decode(result.content_base64).decode("utf-8")
assert decoded == "fallback content"
mock_resolve.assert_awaited_once()
finally:
_current_project_dir.reset(token)
shutil.rmtree(os.path.join(SDK_PROJECTS_DIR, encoded), ignore_errors=True)
@pytest.mark.asyncio(loop_scope="session")
async def test_read_workspace_file_no_fallback_when_resolve_succeeds(setup_test_data):
"""When _resolve_file succeeds, the local-disk fallback must NOT be invoked."""
user = setup_test_data["user"]
session = make_session(user.id)
fake_file_id = "fake-file-id-001"
fake_content = b"workspace content"
# Build a minimal file_info stub that the tool's happy-path needs.
class _FakeFileInfo:
id = fake_file_id
name = "result.json"
path = "/result.json"
mime_type = "text/plain"
size_bytes = len(fake_content)
mock_resolve = AsyncMock(return_value=(fake_file_id, _FakeFileInfo()))
mock_manager = AsyncMock()
mock_manager.read_file_by_id = AsyncMock(return_value=fake_content)
with (
patch("backend.copilot.tools.workspace_files._resolve_file", mock_resolve),
patch(
"backend.copilot.tools.workspace_files.get_workspace_manager",
AsyncMock(return_value=mock_manager),
),
patch(
"backend.copilot.tools.workspace_files._read_local_tool_result"
) as patched_local,
):
read_tool = ReadWorkspaceFileTool()
result = await read_tool._execute(
user_id=user.id,
session=session,
file_id=fake_file_id,
)
# Fallback must not have been called.
patched_local.assert_not_called()
# Normal workspace path must have produced a content response.
assert isinstance(result, WorkspaceFileContentResponse)
assert base64.b64decode(result.content_base64) == fake_content

View File

@@ -70,6 +70,10 @@ def _msg_tokens(msg: dict, enc) -> int:
# Count tool result tokens
tool_call_tokens += _tok_len(item.get("tool_use_id", ""), enc)
tool_call_tokens += _tok_len(item.get("content", ""), enc)
elif isinstance(item, dict) and item.get("type") == "text":
# Count text block tokens (standard: "text" key, fallback: "content")
text_val = item.get("text") or item.get("content", "")
tool_call_tokens += _tok_len(text_val, enc)
elif isinstance(item, dict) and "content" in item:
# Other content types with content field
tool_call_tokens += _tok_len(item.get("content", ""), enc)
@@ -145,10 +149,16 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
if len(ids) <= max_tok:
return text # nothing to do
# Need at least 3 tokens (head + ellipsis + tail) for meaningful truncation
if max_tok < 1:
return ""
mid = enc.encode("")
if max_tok < 3:
return enc.decode(ids[:max_tok])
# Split the allowance between the two ends:
head = max_tok // 2 - 1 # -1 for the ellipsis
tail = max_tok - head - 1
mid = enc.encode("")
return enc.decode(ids[:head] + mid + ids[-tail:])
@@ -545,6 +555,14 @@ async def _summarize_messages_llm(
"- Actions taken and key decisions made\n"
"- Technical specifics (file names, tool outputs, function signatures)\n"
"- Errors encountered and resolutions applied\n\n"
"IMPORTANT: Preserve all concrete references verbatim — these are small but "
"critical for continuing the conversation:\n"
"- File paths and directory paths (e.g. /src/app/page.tsx, ./output/result.csv)\n"
"- Image/media file paths from tool outputs\n"
"- URLs, API endpoints, and webhook addresses\n"
"- Resource IDs, session IDs, and identifiers\n"
"- Tool names that were called and their key parameters\n"
"- Environment variables, config keys, and credentials names (not values)\n\n"
"Include ONLY the sections below that have relevant content "
"(skip sections with nothing to report):\n\n"
"## 1. Primary Request and Intent\n"
@@ -552,7 +570,8 @@ async def _summarize_messages_llm(
"## 2. Key Technical Concepts\n"
"Technologies, frameworks, tools, and patterns being used or discussed.\n\n"
"## 3. Files and Resources Involved\n"
"Specific files examined or modified, with relevant snippets and identifiers.\n\n"
"Specific files examined or modified, with relevant snippets and identifiers. "
"Include exact file paths, image paths from tool outputs, and resource URLs.\n\n"
"## 4. Errors and Fixes\n"
"Problems encountered, error messages, and their resolutions.\n\n"
"## 5. All User Messages\n"
@@ -566,7 +585,7 @@ async def _summarize_messages_llm(
},
{"role": "user", "content": f"Summarize:\n\n{conversation_text}"},
],
max_tokens=1500,
max_tokens=2000,
temperature=0.3,
)
@@ -686,11 +705,15 @@ async def compress_context(
msgs = [summary_msg] + recent_msgs
logger.info(
f"Context summarized: {original_count} -> {total_tokens()} tokens, "
f"summarized {messages_summarized} messages"
"Context summarized: %d -> %d tokens, summarized %d messages",
original_count,
total_tokens(),
messages_summarized,
)
except Exception as e:
logger.warning(f"Summarization failed, continuing with truncation: {e}")
logger.warning(
"Summarization failed, continuing with truncation: %s", e
)
# Fall through to content truncation
# ---- STEP 2: Normalize content ----------------------------------------