chore: merge master into dev, resolve baseline/transcript conflicts

Conflicts in baseline/service.py, baseline/transcript_integration_test.py,
and transcript.py arose because dev-only commit 0cd0a76305
(baseline upload fix) overlapped with the same fix in PR #12804 which
landed in master. Took master's version for all three files — it is the
complete, reviewed implementation.
This commit is contained in:
Zamil Majdy
2026-04-16 15:38:46 +07:00
17 changed files with 2380 additions and 1068 deletions

View File

@@ -67,11 +67,15 @@ from backend.copilot.transcript import (
STOP_REASON_END_TURN,
STOP_REASON_TOOL_USE,
TranscriptDownload,
detect_gap,
download_transcript,
extract_context_messages,
strip_for_upload,
upload_transcript,
validate_transcript,
)
from backend.copilot.transcript_builder import TranscriptBuilder
from backend.util import json as util_json
from backend.util.exceptions import NotFoundError
from backend.util.prompt import (
compress_context,
@@ -699,29 +703,7 @@ async def _compress_session_messages(
return messages
def is_transcript_stale(dl: TranscriptDownload | None, session_msg_count: int) -> bool:
"""Return ``True`` when a download doesn't cover the current session.
A transcript is stale when it has a known ``message_count`` and that
count doesn't reach ``session_msg_count - 1`` (i.e. the session has
already advanced beyond what the stored transcript captures).
Loading a stale transcript would silently drop intermediate turns,
so callers should treat stale as "skip load, skip upload".
An unknown ``message_count`` (``0``) is treated as **not stale**
because older transcripts uploaded before msg_count tracking
existed must still be usable.
"""
if dl is None:
return False
if not dl.message_count:
return False
return dl.message_count < session_msg_count - 1
def should_upload_transcript(
user_id: str | None, upload_safe: bool
) -> bool:
def should_upload_transcript(user_id: str | None, upload_safe: bool) -> bool:
"""Return ``True`` when the caller should upload the final transcript.
Uploads require a logged-in user (for the storage key) *and* a safe
@@ -731,55 +713,137 @@ def should_upload_transcript(
return bool(user_id) and upload_safe
def _append_gap_to_builder(
gap: list[ChatMessage],
builder: TranscriptBuilder,
) -> None:
"""Append gap messages from chat-db into the TranscriptBuilder.
Converts ChatMessage (OpenAI format) to TranscriptBuilder entries
(Claude CLI JSONL format) so the uploaded transcript covers all turns.
Pre-condition: ``gap`` always starts at a user or assistant boundary
(never mid-turn at a ``tool`` role), because ``detect_gap`` enforces
``session_messages[wm-1].role == 'assistant'`` before returning a non-empty
gap. Any ``tool`` role messages within the gap always follow an assistant
entry that already exists in the builder or in the gap itself.
"""
for msg in gap:
if msg.role == "user":
builder.append_user(msg.content or "")
elif msg.role == "assistant":
content_blocks: list[dict] = []
if msg.content:
content_blocks.append({"type": "text", "text": msg.content})
if msg.tool_calls:
for tc in msg.tool_calls:
fn = tc.get("function", {}) if isinstance(tc, dict) else {}
input_data = util_json.loads(fn.get("arguments", "{}"), fallback={})
content_blocks.append(
{
"type": "tool_use",
"id": tc.get("id", "") if isinstance(tc, dict) else "",
"name": fn.get("name", "unknown"),
"input": input_data,
}
)
if not content_blocks:
# Fallback: ensure every assistant gap message produces an entry
# so the builder's entry count matches the gap length.
content_blocks.append({"type": "text", "text": ""})
builder.append_assistant(content_blocks=content_blocks)
elif msg.role == "tool":
if msg.tool_call_id:
builder.append_tool_result(
tool_use_id=msg.tool_call_id,
content=msg.content or "",
)
else:
# Malformed tool message — no tool_call_id to link to an
# assistant tool_use block. Skip to avoid an unmatched
# tool_result entry in the builder (which would confuse --resume).
logger.warning(
"[Baseline] Skipping tool gap message with no tool_call_id"
)
async def _load_prior_transcript(
user_id: str,
session_id: str,
session_msg_count: int,
session_messages: list[ChatMessage],
transcript_builder: TranscriptBuilder,
) -> bool:
"""Download and load the prior transcript into ``transcript_builder``.
) -> tuple[bool, "TranscriptDownload | None"]:
"""Download and load the prior CLI session into ``transcript_builder``.
Returns ``True`` when upload is safe at the end of this turn; ``False``
when GCS has a *newer* version that we must not overwrite (stale case).
Upload is suppressed only for **stale** transcripts (GCS watermark >
current turn's prefix) and **download errors** (we can't know what GCS
holds). Missing and invalid transcripts return ``True`` because there is
nothing in GCS worth protecting — uploading is always safe.
Returns a tuple of (upload_safe, transcript_download):
- ``upload_safe`` is ``True`` when it is safe to upload at the end of this
turn. Upload is suppressed only for **download errors** (unknown GCS
state) — missing and invalid files return ``True`` because there is
nothing in GCS worth protecting against overwriting.
- ``transcript_download`` is a ``TranscriptDownload`` with str content
(pre-decoded and stripped) when available, or ``None`` when no valid
transcript could be loaded. Callers pass this to
``extract_context_messages`` to build the LLM context.
"""
try:
dl = await download_transcript(user_id, session_id, log_prefix="[Baseline]")
except Exception as e:
logger.warning("[Baseline] Transcript download failed: %s", e)
# Unknown GCS state — be conservative and skip upload.
return False
if dl is None:
logger.debug("[Baseline] No transcript available — will upload fresh")
# Nothing in GCS to protect; allow upload.
return True
if not validate_transcript(dl.content):
logger.warning("[Baseline] Downloaded transcript is invalid — will overwrite")
# Corrupt file in GCS; uploading a valid one is strictly better.
return True
if is_transcript_stale(dl, session_msg_count):
logger.warning(
"[Baseline] Transcript stale: covers %d of %d messages, skipping",
dl.message_count,
session_msg_count,
restore = await download_transcript(
user_id, session_id, log_prefix="[Baseline]"
)
# GCS watermark is ahead of this turn — do not overwrite.
return False
except Exception as e:
logger.warning("[Baseline] Session restore failed: %s", e)
# Unknown GCS state — be conservative, skip upload.
return False, None
transcript_builder.load_previous(dl.content, log_prefix="[Baseline]")
if restore is None:
logger.debug("[Baseline] No CLI session available — will upload fresh")
# Nothing in GCS to protect; allow upload so the first baseline turn
# writes the initial transcript snapshot.
return True, None
content_bytes = restore.content
try:
raw_str = (
content_bytes.decode("utf-8")
if isinstance(content_bytes, bytes)
else content_bytes
)
except UnicodeDecodeError:
logger.warning("[Baseline] CLI session content is not valid UTF-8")
# Corrupt file in GCS; overwriting with a valid one is better.
return True, None
stripped = strip_for_upload(raw_str)
if not validate_transcript(stripped):
logger.warning("[Baseline] CLI session content invalid after strip")
# Corrupt file in GCS; overwriting with a valid one is better.
return True, None
transcript_builder.load_previous(stripped, log_prefix="[Baseline]")
logger.info(
"[Baseline] Loaded transcript: %dB, msg_count=%d",
len(dl.content),
dl.message_count,
"[Baseline] Loaded CLI session: %dB, msg_count=%d",
len(content_bytes) if isinstance(content_bytes, bytes) else len(raw_str),
restore.message_count,
)
return True
gap = detect_gap(restore, session_messages)
if gap:
_append_gap_to_builder(gap, transcript_builder)
logger.info(
"[Baseline] Filled gap: loaded %d transcript msgs + %d gap msgs from DB",
restore.message_count,
len(gap),
)
# Return a str-content version so extract_context_messages receives a
# pre-decoded, stripped transcript (avoids redundant decode + strip).
# TranscriptDownload.content is typed as bytes | str; we pass str here
# to avoid a redundant encode + decode round-trip.
str_restore = TranscriptDownload(
content=stripped,
message_count=restore.message_count,
mode=restore.mode,
)
return True, str_restore
async def _upload_final_transcript(
@@ -813,10 +877,10 @@ async def _upload_final_transcript(
upload_transcript(
user_id=user_id,
session_id=session_id,
content=content,
content=content.encode("utf-8"),
message_count=session_msg_count,
mode="baseline",
log_prefix="[Baseline]",
skip_strip=True,
)
)
_background_tasks.add(upload_task)
@@ -920,15 +984,16 @@ async def stream_chat_completion_baseline(
# Run download + prompt build concurrently — both are independent I/O
# on the request critical path.
transcript_download: TranscriptDownload | None = None
if user_id and len(session.messages) > 1:
(
transcript_upload_safe,
(transcript_upload_safe, transcript_download),
(base_system_prompt, understanding),
) = await asyncio.gather(
_load_prior_transcript(
user_id=user_id,
session_id=session_id,
session_msg_count=len(session.messages),
session_messages=session.messages,
transcript_builder=transcript_builder,
),
prompt_task,
@@ -968,9 +1033,14 @@ async def stream_chat_completion_baseline(
warm_ctx = await fetch_warm_context(user_id, message or "")
# Compress context if approaching the model's token limit
# Context path: transcript content (compacted, isCompactSummary preserved) +
# gap (DB messages after watermark) + current user turn.
# This avoids re-reading the full session history from DB on every turn.
# See extract_context_messages() in transcript.py for the shared primitive.
prior_context = extract_context_messages(transcript_download, session.messages)
messages_for_context = await _compress_session_messages(
session.messages, model=active_model
prior_context + ([session.messages[-1]] if session.messages else []),
model=active_model,
)
# Build OpenAI message list from session history.

View File

@@ -1,7 +1,7 @@
"""Integration tests for baseline transcript flow.
Exercises the real helpers in ``baseline/service.py`` that download,
validate, load, append to, backfill, and upload the transcript.
Exercises the real helpers in ``baseline/service.py`` that restore,
validate, load, append to, backfill, and upload the CLI session.
Storage is mocked via ``download_transcript`` / ``upload_transcript``
patches; no network access is required.
"""
@@ -12,13 +12,14 @@ from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.baseline.service import (
_append_gap_to_builder,
_load_prior_transcript,
_record_turn_to_transcript,
_resolve_baseline_model,
_upload_final_transcript,
is_transcript_stale,
should_upload_transcript,
)
from backend.copilot.model import ChatMessage
from backend.copilot.service import config
from backend.copilot.transcript import (
STOP_REASON_END_TURN,
@@ -54,6 +55,13 @@ def _make_transcript_content(*roles: str) -> str:
return "\n".join(lines) + "\n"
def _make_session_messages(*roles: str) -> list[ChatMessage]:
"""Build a list of ChatMessage objects matching the given roles."""
return [
ChatMessage(role=r, content=f"{r} message {i}") for i, r in enumerate(roles)
]
class TestResolveBaselineModel:
"""Model selection honours the per-request mode."""
@@ -73,89 +81,102 @@ class TestResolveBaselineModel:
class TestLoadPriorTranscript:
"""``_load_prior_transcript`` wraps the download + validate + load flow."""
"""``_load_prior_transcript`` wraps the CLI session restore + validate + load flow."""
@pytest.mark.asyncio
async def test_loads_fresh_transcript(self):
builder = TranscriptBuilder()
content = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=content, message_count=2)
restore = TranscriptDownload(
content=content.encode("utf-8"), message_count=2, mode="sdk"
)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
session_messages=_make_session_messages("user", "assistant", "user"),
transcript_builder=builder,
)
assert covers is True
assert dl is not None
assert dl.message_count == 2
assert builder.entry_count == 2
assert builder.last_entry_type == "assistant"
@pytest.mark.asyncio
async def test_rejects_stale_transcript(self):
"""msg_count strictly less than session-1 is treated as stale."""
async def test_fills_gap_when_transcript_is_behind(self):
"""When transcript covers fewer messages than session, gap is filled from DB."""
builder = TranscriptBuilder()
content = _make_transcript_content("user", "assistant")
# session has 6 messages, transcript only covers 2 → stale.
download = TranscriptDownload(content=content, message_count=2)
# transcript covers 2 messages, session has 4 (plus current user turn = 5)
restore = TranscriptDownload(
content=content.encode("utf-8"), message_count=2, mode="baseline"
)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=6,
session_messages=_make_session_messages(
"user", "assistant", "user", "assistant", "user"
),
transcript_builder=builder,
)
assert covers is False
assert builder.is_empty
assert covers is True
assert dl is not None
# 2 from transcript + 2 gap messages (user+assistant at positions 2,3)
assert builder.entry_count == 4
@pytest.mark.asyncio
async def test_missing_transcript_allows_upload(self):
"""Nothing in GCS → safe to upload fresh transcript after the turn."""
"""Nothing in GCS → upload is safe; the turn writes the first snapshot."""
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=None),
):
upload_safe = await _load_prior_transcript(
upload_safe, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
session_messages=_make_session_messages("user", "assistant"),
transcript_builder=builder,
)
assert upload_safe is True
assert dl is None
assert builder.is_empty
@pytest.mark.asyncio
async def test_invalid_transcript_allows_upload(self):
"""Corrupt file in GCS → overwriting with valid data is better."""
"""Corrupt file in GCS → overwriting with a valid one is better."""
builder = TranscriptBuilder()
download = TranscriptDownload(
content='{"type":"progress","uuid":"a"}\n',
restore = TranscriptDownload(
content=b'{"type":"progress","uuid":"a"}\n',
message_count=1,
mode="sdk",
)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
upload_safe = await _load_prior_transcript(
upload_safe, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
session_messages=_make_session_messages("user", "assistant"),
transcript_builder=builder,
)
assert upload_safe is True
assert dl is None
assert builder.is_empty
@pytest.mark.asyncio
@@ -165,36 +186,39 @@ class TestLoadPriorTranscript:
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(side_effect=RuntimeError("boom")),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
session_messages=_make_session_messages("user", "assistant"),
transcript_builder=builder,
)
assert covers is False
assert dl is None
assert builder.is_empty
@pytest.mark.asyncio
async def test_zero_message_count_not_stale(self):
"""When msg_count is 0 (unknown), staleness check is skipped."""
"""When msg_count is 0 (unknown), gap detection is skipped."""
builder = TranscriptBuilder()
download = TranscriptDownload(
content=_make_transcript_content("user", "assistant"),
restore = TranscriptDownload(
content=_make_transcript_content("user", "assistant").encode("utf-8"),
message_count=0,
mode="sdk",
)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=20,
session_messages=_make_session_messages(*["user"] * 20),
transcript_builder=builder,
)
assert covers is True
assert dl is not None
assert builder.entry_count == 2
@@ -229,7 +253,7 @@ class TestUploadFinalTranscript:
assert call_kwargs["user_id"] == "user-1"
assert call_kwargs["session_id"] == "session-1"
assert call_kwargs["message_count"] == 2
assert "hello" in call_kwargs["content"]
assert b"hello" in call_kwargs["content"]
@pytest.mark.asyncio
async def test_skips_upload_when_builder_empty(self):
@@ -376,17 +400,19 @@ class TestRoundTrip:
@pytest.mark.asyncio
async def test_full_round_trip(self):
prior = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=prior, message_count=2)
restore = TranscriptDownload(
content=prior.encode("utf-8"), message_count=2, mode="sdk"
)
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, _ = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
session_messages=_make_session_messages("user", "assistant", "user"),
transcript_builder=builder,
)
assert covers is True
@@ -426,11 +452,11 @@ class TestRoundTrip:
upload_mock.assert_awaited_once()
assert upload_mock.await_args is not None
uploaded = upload_mock.await_args.kwargs["content"]
assert "new question" in uploaded
assert "new answer" in uploaded
assert b"new question" in uploaded
assert b"new answer" in uploaded
# Original content preserved in the round trip.
assert "user message 0" in uploaded
assert "assistant message 1" in uploaded
assert b"user message 0" in uploaded
assert b"assistant message 1" in uploaded
@pytest.mark.asyncio
async def test_backfill_append_guard(self):
@@ -461,36 +487,6 @@ class TestRoundTrip:
assert builder.entry_count == initial_count
class TestIsTranscriptStale:
"""``is_transcript_stale`` gates prior-transcript loading."""
def test_none_download_is_not_stale(self):
assert is_transcript_stale(None, session_msg_count=5) is False
def test_zero_message_count_is_not_stale(self):
"""Legacy transcripts without msg_count tracking must remain usable."""
dl = TranscriptDownload(content="", message_count=0)
assert is_transcript_stale(dl, session_msg_count=20) is False
def test_stale_when_covers_less_than_prefix(self):
dl = TranscriptDownload(content="", message_count=2)
# session has 6 messages; transcript must cover at least 5 (6-1).
assert is_transcript_stale(dl, session_msg_count=6) is True
def test_fresh_when_covers_full_prefix(self):
dl = TranscriptDownload(content="", message_count=5)
assert is_transcript_stale(dl, session_msg_count=6) is False
def test_fresh_when_exceeds_prefix(self):
"""Race: transcript ahead of session count is still acceptable."""
dl = TranscriptDownload(content="", message_count=10)
assert is_transcript_stale(dl, session_msg_count=6) is False
def test_boundary_equal_to_prefix_minus_one(self):
dl = TranscriptDownload(content="", message_count=5)
assert is_transcript_stale(dl, session_msg_count=6) is False
class TestShouldUploadTranscript:
"""``should_upload_transcript`` gates the final upload."""
@@ -512,7 +508,7 @@ class TestShouldUploadTranscript:
class TestTranscriptLifecycle:
"""End-to-end: download → validate → build → upload.
"""End-to-end: restore → validate → build → upload.
Simulates the full transcript lifecycle inside
``stream_chat_completion_baseline`` by mocking the storage layer and
@@ -521,27 +517,29 @@ class TestTranscriptLifecycle:
@pytest.mark.asyncio
async def test_full_lifecycle_happy_path(self):
"""Fresh download, append a turn, upload covers the session."""
"""Fresh restore, append a turn, upload covers the session."""
builder = TranscriptBuilder()
prior = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=prior, message_count=2)
restore = TranscriptDownload(
content=prior.encode("utf-8"), message_count=2, mode="sdk"
)
upload_mock = AsyncMock(return_value=None)
with (
patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
),
patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
),
):
# --- 1. Download & load prior transcript ---
covers = await _load_prior_transcript(
# --- 1. Restore & load prior session ---
covers, _ = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
session_messages=_make_session_messages("user", "assistant", "user"),
transcript_builder=builder,
)
assert covers is True
@@ -561,10 +559,7 @@ class TestTranscriptLifecycle:
# --- 3. Gate + upload ---
assert (
should_upload_transcript(
user_id="user-1", upload_safe=covers
)
is True
should_upload_transcript(user_id="user-1", upload_safe=covers) is True
)
await _upload_final_transcript(
user_id="user-1",
@@ -576,20 +571,21 @@ class TestTranscriptLifecycle:
upload_mock.assert_awaited_once()
assert upload_mock.await_args is not None
uploaded = upload_mock.await_args.kwargs["content"]
assert "follow-up question" in uploaded
assert "follow-up answer" in uploaded
assert b"follow-up question" in uploaded
assert b"follow-up answer" in uploaded
# Original prior-turn content preserved.
assert "user message 0" in uploaded
assert "assistant message 1" in uploaded
assert b"user message 0" in uploaded
assert b"assistant message 1" in uploaded
@pytest.mark.asyncio
async def test_lifecycle_stale_download_suppresses_upload(self):
"""Stale download → covers=False → upload must be skipped."""
async def test_lifecycle_stale_download_fills_gap(self):
"""When transcript covers fewer messages, gap is filled rather than rejected."""
builder = TranscriptBuilder()
# session has 10 msgs but stored transcript only covers 2 → stale.
# session has 5 msgs but stored transcript only covers 2 → gap filled.
stale = TranscriptDownload(
content=_make_transcript_content("user", "assistant"),
content=_make_transcript_content("user", "assistant").encode("utf-8"),
message_count=2,
mode="baseline",
)
upload_mock = AsyncMock(return_value=None)
@@ -603,20 +599,18 @@ class TestTranscriptLifecycle:
new=upload_mock,
),
):
covers = await _load_prior_transcript(
covers, _ = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=10,
session_messages=_make_session_messages(
"user", "assistant", "user", "assistant", "user"
),
transcript_builder=builder,
)
assert covers is False
# The caller's gate mirrors the production path.
assert (
should_upload_transcript(user_id="user-1", upload_safe=covers)
is False
)
upload_mock.assert_not_awaited()
assert covers is True
# Gap was filled: 2 from transcript + 2 gap messages
assert builder.entry_count == 4
@pytest.mark.asyncio
async def test_lifecycle_anonymous_user_skips_upload(self):
@@ -629,14 +623,11 @@ class TestTranscriptLifecycle:
stop_reason=STOP_REASON_END_TURN,
)
assert (
should_upload_transcript(user_id=None, upload_safe=True)
is False
)
assert should_upload_transcript(user_id=None, upload_safe=True) is False
@pytest.mark.asyncio
async def test_lifecycle_missing_download_still_uploads_new_content(self):
"""No prior transcript → upload is safe; the turn writes the first snapshot."""
"""No prior session → upload is safe; the turn writes the first snapshot."""
builder = TranscriptBuilder()
upload_mock = AsyncMock(return_value=None)
with (
@@ -649,18 +640,117 @@ class TestTranscriptLifecycle:
new=upload_mock,
),
):
upload_safe = await _load_prior_transcript(
upload_safe, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=1,
session_messages=_make_session_messages("user"),
transcript_builder=builder,
)
# Nothing in GCS → upload is safe so the first baseline turn
# can write the initial snapshot.
# can write the initial transcript snapshot.
assert upload_safe is True
assert dl is None
assert (
should_upload_transcript(
user_id="user-1", upload_safe=upload_safe
)
should_upload_transcript(user_id="user-1", upload_safe=upload_safe)
is True
)
# ---------------------------------------------------------------------------
# _append_gap_to_builder
# ---------------------------------------------------------------------------
class TestAppendGapToBuilder:
"""``_append_gap_to_builder`` converts ChatMessage objects to TranscriptBuilder entries."""
def test_user_message_appended(self):
builder = TranscriptBuilder()
msgs = [ChatMessage(role="user", content="hello")]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 1
assert builder.last_entry_type == "user"
def test_assistant_text_message_appended(self):
builder = TranscriptBuilder()
msgs = [
ChatMessage(role="user", content="q"),
ChatMessage(role="assistant", content="answer"),
]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 2
assert builder.last_entry_type == "assistant"
assert "answer" in builder.to_jsonl()
def test_assistant_with_tool_calls_appended(self):
"""Assistant tool_calls are recorded as tool_use blocks in the transcript."""
builder = TranscriptBuilder()
tool_call = {
"id": "tc-1",
"type": "function",
"function": {"name": "my_tool", "arguments": '{"key":"val"}'},
}
msgs = [ChatMessage(role="assistant", content=None, tool_calls=[tool_call])]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 1
jsonl = builder.to_jsonl()
assert "tool_use" in jsonl
assert "my_tool" in jsonl
assert "tc-1" in jsonl
def test_assistant_invalid_json_args_uses_empty_dict(self):
"""Malformed JSON in tool_call arguments falls back to {}."""
builder = TranscriptBuilder()
tool_call = {
"id": "tc-bad",
"type": "function",
"function": {"name": "bad_tool", "arguments": "not-json"},
}
msgs = [ChatMessage(role="assistant", content=None, tool_calls=[tool_call])]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 1
jsonl = builder.to_jsonl()
assert '"input":{}' in jsonl
def test_assistant_empty_content_and_no_tools_uses_fallback(self):
"""Assistant with no content and no tool_calls gets a fallback empty text block."""
builder = TranscriptBuilder()
msgs = [ChatMessage(role="assistant", content=None)]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 1
jsonl = builder.to_jsonl()
assert "text" in jsonl
def test_tool_role_with_tool_call_id_appended(self):
"""Tool result messages are appended when tool_call_id is set."""
builder = TranscriptBuilder()
# Need a preceding assistant tool_use entry
builder.append_user("use tool")
builder.append_assistant(
content_blocks=[
{"type": "tool_use", "id": "tc-1", "name": "my_tool", "input": {}}
]
)
msgs = [ChatMessage(role="tool", tool_call_id="tc-1", content="result")]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 3
assert "tool_result" in builder.to_jsonl()
def test_tool_role_without_tool_call_id_skipped(self):
"""Tool messages without tool_call_id are silently skipped."""
builder = TranscriptBuilder()
msgs = [ChatMessage(role="tool", tool_call_id=None, content="orphan")]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 0
def test_tool_call_missing_function_key_uses_unknown_name(self):
"""A tool_call dict with no 'function' key uses 'unknown' as the tool name."""
builder = TranscriptBuilder()
# Tool call dict exists but 'function' sub-dict is missing entirely
msgs = [
ChatMessage(role="assistant", content=None, tool_calls=[{"id": "tc-x"}])
]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 1
jsonl = builder.to_jsonl()
assert "unknown" in jsonl

View File

@@ -23,7 +23,7 @@ if TYPE_CHECKING:
# Allowed base directory for the Read tool. Public so service.py can use it
# for sweep operations without depending on a private implementation detail.
# Respects CLAUDE_CONFIG_DIR env var, consistent with transcript.py's
# _projects_base() function.
# projects_base() function.
_config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
SDK_PROJECTS_DIR = os.path.realpath(os.path.join(_config_dir, "projects"))

View File

@@ -174,6 +174,7 @@ sandbox so `bash_exec` can access it for further processing.
The exact sandbox path is shown in the `[Sandbox copy available at ...]` note.
### GitHub CLI (`gh`) and git
- To check if the user has their GitHub account already connected, run `gh auth status`. Always check this before asking them to connect it.
- If the user has connected their GitHub account, both `gh` and `git` are
pre-authenticated — use them directly without any manual login step.
`git` HTTPS operations (clone, push, pull) work automatically.

View File

@@ -8,7 +8,7 @@ Cross-mode transcript flow
==========================
Both ``baseline/service.py`` (fast mode) and ``sdk/service.py`` (extended_thinking
mode) read and write the same JSONL transcript store via
mode) read and write the same CLI session store via
``backend.copilot.transcript.upload_transcript`` /
``download_transcript``.
@@ -250,8 +250,9 @@ class TestSdkToFastModeSwitch:
@pytest.mark.asyncio
async def test_scenario_s_baseline_loads_sdk_transcript(self):
"""Scenario S: SDK-written transcript is accepted by baseline's load helper."""
"""Scenario S: SDK-written CLI session is accepted by baseline's load helper."""
from backend.copilot.baseline.service import _load_prior_transcript
from backend.copilot.model import ChatMessage
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
from backend.copilot.transcript_builder import TranscriptBuilder
@@ -267,33 +268,41 @@ class TestSdkToFastModeSwitch:
sdk_transcript = builder_sdk.to_jsonl()
# Baseline session now has those 2 SDK messages + 1 new baseline message.
download = TranscriptDownload(content=sdk_transcript, message_count=2)
restore = TranscriptDownload(
content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk"
)
baseline_builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3, # 2 SDK + 1 new baseline
session_messages=[
ChatMessage(role="user", content="sdk-question"),
ChatMessage(role="assistant", content="sdk-answer"),
ChatMessage(role="user", content="baseline-question"),
],
transcript_builder=baseline_builder,
)
# Transcript is valid and covers the prefix.
# CLI session is valid and covers the prefix.
assert covers is True
assert dl is not None
assert baseline_builder.entry_count == 2
@pytest.mark.asyncio
async def test_scenario_s_stale_sdk_transcript_not_loaded(self):
"""Scenario S (stale): SDK transcript is stale — baseline does not load it.
"""Scenario S (stale): SDK CLI session is stale — baseline does not load it.
If SDK mode produced more turns than the transcript captured (e.g.
upload failed on one turn), the baseline rejects the stale transcript
If SDK mode produced more turns than the session captured (e.g.
upload failed on one turn), the baseline rejects the stale session
to avoid injecting an incomplete history.
"""
from backend.copilot.baseline.service import _load_prior_transcript
from backend.copilot.model import ChatMessage
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
from backend.copilot.transcript_builder import TranscriptBuilder
@@ -306,21 +315,33 @@ class TestSdkToFastModeSwitch:
)
sdk_transcript = builder_sdk.to_jsonl()
# Transcript covers only 2 messages but session has 10 (many SDK turns).
download = TranscriptDownload(content=sdk_transcript, message_count=2)
# Session covers only 2 messages but session has 10 (many SDK turns).
# With watermark=2 and 10 total messages, detect_gap will fill the gap
# by appending messages 2..8 (positions 2 to total-2).
restore = TranscriptDownload(
content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk"
)
# Build a session with 10 alternating user/assistant messages + current user
session_messages = [
ChatMessage(role="user" if i % 2 == 0 else "assistant", content=f"msg-{i}")
for i in range(10)
]
baseline_builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=10,
session_messages=session_messages,
transcript_builder=baseline_builder,
)
# Stale transcript must be rejected.
assert covers is False
assert baseline_builder.is_empty
# With gap filling, covers is True and gap messages are appended.
assert covers is True
assert dl is not None
# 2 from transcript + 7 gap messages (positions 2..8, excluding last user turn)
assert baseline_builder.entry_count == 9

View File

@@ -27,6 +27,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.transcript import (
TranscriptDownload,
_flatten_assistant_content,
_flatten_tool_result_content,
_messages_to_transcript,
@@ -999,14 +1000,15 @@ def _make_sdk_patches(
f"{_SVC}.download_transcript",
dict(
new_callable=AsyncMock,
return_value=MagicMock(content=original_transcript, message_count=2),
return_value=TranscriptDownload(
content=original_transcript.encode("utf-8"),
message_count=2,
mode="sdk",
),
),
),
(
f"{_SVC}.restore_cli_session",
dict(new_callable=AsyncMock, return_value=True),
),
(f"{_SVC}.upload_cli_session", dict(new_callable=AsyncMock)),
(f"{_SVC}.strip_for_upload", dict(return_value=original_transcript)),
(f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)),
(f"{_SVC}.validate_transcript", dict(return_value=True)),
(
f"{_SVC}.compact_transcript",
@@ -1037,7 +1039,6 @@ def _make_sdk_patches(
claude_agent_fallback_model=None,
),
),
(f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)),
(f"{_SVC}.get_user_tier", dict(new_callable=AsyncMock, return_value=None)),
]
@@ -1914,14 +1915,14 @@ class TestStreamChatCompletionRetryIntegration:
compacted_transcript=None,
client_side_effect=_client_factory,
)
# Override restore_cli_session to return False (CLI native session unavailable)
# Override download_transcript to return None (CLI native session unavailable)
patches = [
(
(
f"{_SVC}.restore_cli_session",
dict(new_callable=AsyncMock, return_value=False),
f"{_SVC}.download_transcript",
dict(new_callable=AsyncMock, return_value=None),
)
if p[0] == f"{_SVC}.restore_cli_session"
if p[0] == f"{_SVC}.download_transcript"
else p
)
for p in patches
@@ -1944,7 +1945,7 @@ class TestStreamChatCompletionRetryIntegration:
# captured_options holds {"options": ClaudeAgentOptions}, so check
# the attribute directly rather than dict keys.
assert not getattr(captured_options.get("options"), "resume", None), (
f"--resume was set even though restore_cli_session returned False: "
f"--resume was set even though download_transcript returned None: "
f"{captured_options}"
)
assert any(isinstance(e, StreamStart) for e in events)

View File

@@ -365,7 +365,7 @@ def create_security_hooks(
trigger = _sanitize(str(input_data.get("trigger", "auto")), max_len=50)
# Sanitize untrusted input: strip control chars for logging AND
# for the value passed downstream. read_compacted_entries()
# validates against _projects_base() as defence-in-depth, but
# validates against projects_base() as defence-in-depth, but
# sanitizing here prevents log injection and rejects obviously
# malformed paths early.
transcript_path = _sanitize(

View File

@@ -16,6 +16,7 @@ import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from dataclasses import dataclass
from dataclasses import field as dataclass_field
from pathlib import Path
from typing import TYPE_CHECKING, Any, NamedTuple, NotRequired, cast
if TYPE_CHECKING:
@@ -92,12 +93,15 @@ from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
from ..transcript import (
_run_compression,
TranscriptDownload,
cleanup_stale_project_dirs,
cli_session_path,
compact_transcript,
download_transcript,
extract_context_messages,
projects_base,
read_compacted_entries,
restore_cli_session,
upload_cli_session,
strip_for_upload,
upload_transcript,
validate_transcript,
)
@@ -849,6 +853,181 @@ def _make_sdk_cwd(session_id: str) -> str:
return cwd
def _write_cli_session_to_disk(
content: bytes,
sdk_cwd: str,
session_id: str,
log_prefix: str,
) -> bool:
"""Write downloaded CLI session bytes to disk so the CLI can --resume.
Returns True on success, False if the path is invalid or the write fails.
Path-traversal guard: rejects paths outside the CLI projects base.
"""
session_file = cli_session_path(sdk_cwd, session_id)
real_path = os.path.realpath(session_file)
_pbase = projects_base()
if not real_path.startswith(_pbase + os.sep):
logger.warning(
"%s CLI session restore path outside projects base: %s",
log_prefix,
os.path.basename(session_file),
)
return False
try:
os.makedirs(os.path.dirname(real_path), exist_ok=True)
Path(real_path).write_bytes(content)
logger.info(
"%s Wrote CLI session to disk (%dB) for --resume",
log_prefix,
len(content),
)
return True
except OSError as e:
logger.warning(
"%s Failed to write CLI session file %s: %s",
log_prefix,
os.path.basename(session_file),
e.strerror or str(e),
)
return False
def _read_cli_session_from_disk(
sdk_cwd: str,
session_id: str,
log_prefix: str,
) -> bytes | None:
"""Read the CLI session JSONL file from disk after the SDK turn.
Returns the file bytes, or None if the file is missing, outside the
projects base, or unreadable.
Path-traversal guard: rejects paths outside the CLI projects base.
"""
session_file = cli_session_path(sdk_cwd, session_id)
real_path = os.path.realpath(session_file)
_pbase = projects_base()
if not real_path.startswith(_pbase + os.sep):
logger.warning(
"%s CLI session file outside projects base, skipping upload: %s",
log_prefix,
os.path.basename(real_path),
)
return None
try:
raw_bytes = Path(real_path).read_bytes()
except FileNotFoundError:
logger.debug(
"%s CLI session file not found, skipping upload: %s",
log_prefix,
os.path.basename(session_file),
)
return None
except OSError as e:
logger.warning(
"%s Failed to read CLI session file %s: %s",
log_prefix,
os.path.basename(session_file),
e.strerror or str(e),
)
return None
# Strip stale thinking blocks and metadata entries before uploading.
# Thinking blocks from non-last turns can be massive; keeping them causes
# the CLI to auto-compact its session when the context window fills up,
# silently losing conversation history.
try:
raw_text = raw_bytes.decode("utf-8")
stripped_text = strip_for_upload(raw_text)
stripped_bytes = stripped_text.encode("utf-8")
except UnicodeDecodeError:
logger.warning("%s CLI session is not valid UTF-8, uploading raw", log_prefix)
return raw_bytes
except (OSError, ValueError) as e:
# OSError: encode/decode I/O failure; ValueError: malformed JSONL in strip.
# Other unexpected exceptions are not silently swallowed here so they propagate
# to the outer OSError handler and are logged with exc_info.
logger.warning(
"%s Failed to strip CLI session, uploading raw: %s", log_prefix, e
)
return raw_bytes
if len(stripped_bytes) < len(raw_bytes):
# Write back locally so same-pod turns also benefit.
try:
Path(real_path).write_bytes(stripped_bytes)
logger.info(
"%s Stripped CLI session: %dB → %dB",
log_prefix,
len(raw_bytes),
len(stripped_bytes),
)
except OSError as e:
# write_bytes failed — stripped content is still valid for GCS upload even
# though the local write-back failed (same-pod optimization silently skipped).
logger.warning(
"%s Failed to write back stripped CLI session: %s",
log_prefix,
e.strerror or str(e),
)
return stripped_bytes
def _process_cli_restore(
cli_restore: TranscriptDownload,
sdk_cwd: str,
session_id: str,
log_prefix: str,
) -> tuple[str, bool]:
"""Validate and write a restored CLI session to disk.
Decodes bytes → UTF-8, strips progress entries and stale thinking blocks,
validates the result, then writes the stripped content to disk so the CLI
can ``--resume`` from it.
Returns ``(stripped_content, success)`` where ``success=False`` means the
content was invalid or the disk write failed (caller should skip --resume).
"""
try:
raw_bytes = cli_restore.content
raw_str = (
raw_bytes.decode("utf-8") if isinstance(raw_bytes, bytes) else raw_bytes
)
except UnicodeDecodeError:
logger.warning(
"%s CLI session content is not valid UTF-8, skipping", log_prefix
)
return "", False
stripped = strip_for_upload(raw_str)
is_valid = validate_transcript(stripped)
# Use len(raw_str) rather than len(cli_restore.content) so the unit is always
# characters (raw_str is always str at this point regardless of input type).
# lines_stripped = original lines minus remaining lines after stripping.
_original_lines = len(raw_str.strip().split("\n")) if raw_str.strip() else 0
_remaining_lines = len(stripped.strip().split("\n")) if stripped.strip() else 0
logger.info(
"%s Restored CLI session: %dB raw, %d lines stripped, msg_count=%d, valid=%s",
log_prefix,
len(raw_str),
_original_lines - _remaining_lines,
cli_restore.message_count,
is_valid,
)
if not is_valid:
logger.warning(
"%s CLI session content invalid after strip — running without --resume",
log_prefix,
)
return "", False
stripped_bytes = stripped.encode("utf-8")
if not _write_cli_session_to_disk(stripped_bytes, sdk_cwd, session_id, log_prefix):
return "", False
return stripped, True
async def _cleanup_sdk_tool_results(cwd: str) -> None:
"""Remove SDK session artifacts for a specific working directory.
@@ -922,8 +1101,9 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
result.append(block)
else:
logger.warning(
f"[SDK] Unknown content block type: {type(block).__name__}. "
f"This may indicate a new SDK version with additional block types."
"[SDK] Unknown content block type: %s."
" This may indicate a new SDK version with additional block types.",
type(block).__name__,
)
return result
@@ -978,10 +1158,11 @@ async def _compress_messages(
if result.was_compacted:
logger.info(
f"[SDK] Context compacted: {result.original_token_count} -> "
f"{result.token_count} tokens "
f"({result.messages_summarized} summarized, "
f"{result.messages_dropped} dropped)"
"[SDK] Context compacted: %d -> %d tokens (%d summarized, %d dropped)",
result.original_token_count,
result.token_count,
result.messages_summarized,
result.messages_dropped,
)
# Convert compressed dicts back to ChatMessages
return [
@@ -1048,11 +1229,17 @@ def _session_messages_to_transcript(messages: list[ChatMessage]) -> str:
)
if blocks:
builder.append_assistant(blocks)
elif msg.role == "tool" and msg.tool_call_id:
builder.append_tool_result(
tool_use_id=msg.tool_call_id,
content=msg.content or "",
)
elif msg.role == "tool":
if msg.tool_call_id:
builder.append_tool_result(
tool_use_id=msg.tool_call_id,
content=msg.content or "",
)
else:
# Malformed tool message — no tool_call_id to link to an
# assistant tool_use block. Skip to avoid an unmatched
# tool_result entry in the builder (which would confuse --resume).
logger.warning("[SDK] Skipping tool gap message with no tool_call_id")
return builder.to_jsonl()
@@ -1098,6 +1285,7 @@ async def _build_query_message(
transcript_msg_count: int,
session_id: str,
target_tokens: int | None = None,
prior_messages: "list[ChatMessage] | None" = None,
) -> tuple[str, bool]:
"""Build the query message with appropriate context.
@@ -1203,15 +1391,16 @@ async def _build_query_message(
)
return current_message, False
source = prior_messages if prior_messages is not None else prior
logger.warning(
"[SDK] [%s] No --resume for %d-message session — compressing"
" full session history (pod affinity issue or first turn after"
" restore failure); target_tokens=%s",
"[SDK] [%s] No --resume for %d-message session — compressing context "
"(source=%s, target_tokens=%s)",
session_id[:8],
msg_count,
"transcript+gap" if prior_messages is not None else "full-db",
target_tokens,
)
compressed, was_compressed = await _compress_messages(prior, target_tokens)
compressed, was_compressed = await _compress_messages(source, target_tokens)
history_context = _format_conversation_context(compressed)
if history_context:
logger.info(
@@ -1228,7 +1417,7 @@ async def _build_query_message(
"[SDK] [%s] Fallback context empty after compression"
" (%d messages) — sending message without history",
session_id[:8],
len(prior),
len(source),
)
return current_message, False
@@ -2233,6 +2422,163 @@ async def _seed_transcript(
return _seeded, True, len(_prior)
@dataclass
class _RestoreResult:
"""Return value from ``_restore_cli_session_for_turn``."""
transcript_content: str = ""
transcript_covers_prefix: bool = True
use_resume: bool = False
resume_file: str | None = None
transcript_msg_count: int = 0
baseline_download: "TranscriptDownload | None" = None
context_messages: "list[ChatMessage] | None" = None
async def _restore_cli_session_for_turn(
user_id: str | None,
session_id: str,
session: "ChatSession",
sdk_cwd: str,
transcript_builder: "TranscriptBuilder",
log_prefix: str,
) -> _RestoreResult:
"""Download, validate and restore a CLI session for ``--resume`` on this turn.
Performs a single GCS round-trip to fetch the session bytes + message_count
watermark. Falls back to DB-message reconstruction when GCS has no session
(first turn or upload missed).
Returns a ``_RestoreResult`` with all transcript-related state ready for the
caller to merge into its local variables.
"""
result = _RestoreResult()
if not (config.claude_agent_use_resume and user_id and len(session.messages) > 1):
return result
try:
cli_restore = await download_transcript(
user_id, session_id, log_prefix=log_prefix
)
except Exception as restore_err:
logger.warning(
"%s CLI session restore failed, continuing without --resume: %s",
log_prefix,
restore_err,
)
cli_restore = None
# Only attempt --resume for SDK-written transcripts.
# Baseline-written transcripts use TranscriptBuilder format (synthetic IDs,
# stripped fields) that may not be valid for --resume.
if cli_restore is not None and cli_restore.mode != "sdk":
logger.info(
"%s Transcript written by mode=%r — skipping --resume, "
"will use transcript content + gap for context",
log_prefix,
cli_restore.mode,
)
result.baseline_download = cli_restore # keep for extract_context_messages
cli_restore = None
# Validate, strip, and write to disk — delegate to helper to reduce
# function complexity. Writing an invalid/corrupt file to disk then
# falling back to "no --resume" would cause the CLI to fail with
# "Session ID already in use" because the file exists at the expected
# session path, so we validate BEFORE any disk write.
stripped = ""
if cli_restore is not None and sdk_cwd:
stripped, ok = _process_cli_restore(
cli_restore, sdk_cwd, session_id, log_prefix
)
if not ok:
result.transcript_covers_prefix = False
cli_restore = None
if cli_restore is None and sdk_cwd:
# Validation failed or GCS returned no session. Delete any
# existing local session file so the CLI doesn't reject the
# session_id with "Session ID already in use". T1 may have
# left a valid file at this path; we clear it so the fallback
# path (session_id= without --resume) can create a new session.
_stale_path = os.path.realpath(cli_session_path(sdk_cwd, session_id))
if Path(_stale_path).exists() and _stale_path.startswith(
projects_base() + os.sep
):
try:
Path(_stale_path).unlink()
logger.debug(
"%s Removed stale local CLI session file for clean fallback",
log_prefix,
)
except OSError as _unlink_err:
logger.debug(
"%s Failed to remove stale local session file: %s",
log_prefix,
_unlink_err,
)
if cli_restore is not None:
result.transcript_content = stripped
transcript_builder.load_previous(stripped, log_prefix=log_prefix)
result.use_resume = True
result.resume_file = session_id
result.transcript_msg_count = cli_restore.message_count
return result
# No valid --resume source (mode="baseline" or no GCS file).
# Build context from transcript content + gap, falling back to full DB.
# extract_context_messages handles both: non-None baseline_download uses
# the compacted transcript + gap; None falls back to all prior DB messages.
context_msgs = extract_context_messages(result.baseline_download, session.messages)
result.context_messages = context_msgs
result.transcript_msg_count = (
result.baseline_download.message_count
if result.baseline_download is not None
and result.baseline_download.message_count > 0
else len(session.messages) - 1
)
result.transcript_covers_prefix = True
logger.info(
"%s Context built from %s: %d messages (transcript watermark=%d, "
"will inject as <conversation_history>)",
log_prefix,
(
"baseline transcript + gap"
if result.baseline_download is not None
else "DB fallback"
),
len(context_msgs),
result.transcript_msg_count,
)
# Load baseline transcript content into builder so the upload path has accurate state.
# Also sets result.transcript_content so the _seed_transcript guard in the caller
# (``not transcript_content``) does not overwrite this builder state with a DB
# reconstruction — which would duplicate entries since load_previous appends.
if result.baseline_download is not None:
try:
raw_for_builder = result.baseline_download.content
if isinstance(raw_for_builder, bytes):
raw_for_builder = raw_for_builder.decode("utf-8")
stripped = strip_for_upload(raw_for_builder)
if validate_transcript(stripped):
transcript_builder.load_previous(stripped, log_prefix=log_prefix)
result.transcript_content = stripped
except (UnicodeDecodeError, ValueError, OSError) as _load_err:
# UnicodeDecodeError: non-UTF-8 content; ValueError: malformed JSONL in
# strip_for_upload; OSError: encode/decode I/O failure. Unexpected
# exceptions propagate so programming errors are not silently masked.
logger.debug(
"%s Could not load baseline transcript into builder: %s",
log_prefix,
_load_err,
)
return result
async def stream_chat_completion_sdk(
session_id: str,
message: str | None = None,
@@ -2427,28 +2773,9 @@ async def stream_chat_completion_sdk(
return sandbox
async def _fetch_transcript():
"""Download transcript for --resume if applicable."""
if not (
config.claude_agent_use_resume and user_id and len(session.messages) > 1
):
return None
try:
return await download_transcript(
user_id, session_id, log_prefix=log_prefix
)
except Exception as transcript_err:
logger.warning(
"%s Transcript download failed, continuing without --resume: %s",
log_prefix,
transcript_err,
)
return None
e2b_sandbox, (base_system_prompt, understanding), dl = await asyncio.gather(
e2b_sandbox, (base_system_prompt, understanding) = await asyncio.gather(
_setup_e2b(),
_build_system_prompt(user_id if not has_history else None),
_fetch_transcript(),
)
use_e2b = e2b_sandbox is not None
@@ -2473,95 +2800,17 @@ async def stream_chat_completion_sdk(
warm_ctx = await fetch_warm_context(user_id, message or "") or ""
# Process transcript download result and restore CLI native session.
# The CLI native session file (uploaded after each turn) is the
# source of truth for --resume. Our custom JSONL (TranscriptEntry)
# is loaded into the builder for future upload_transcript calls.
transcript_msg_count = 0
if dl:
is_valid = validate_transcript(dl.content)
dl_lines = dl.content.strip().split("\n") if dl.content else []
logger.info(
"%s Downloaded transcript: %dB, %d lines, msg_count=%d, valid=%s",
log_prefix,
len(dl.content),
len(dl_lines),
dl.message_count,
is_valid,
)
if is_valid:
# Load previous FULL context into builder for state tracking.
transcript_content = dl.content
transcript_builder.load_previous(dl.content, log_prefix=log_prefix)
# Restore CLI's native session file so --resume session_id works.
# Falls back gracefully if not available (first turn or upload missed).
# user_id is guaranteed non-None here: _fetch_transcript only sets dl
# when `config.claude_agent_use_resume and user_id` is truthy.
cli_restored = user_id is not None and await restore_cli_session(
user_id, session_id, sdk_cwd, log_prefix=log_prefix
)
if cli_restored:
use_resume = True
resume_file = session_id # CLI --resume expects UUID, not file path
transcript_msg_count = dl.message_count
logger.info(
"%s Using --resume %s (%dB transcript, msg_count=%d)",
log_prefix,
session_id[:8],
len(dl.content),
transcript_msg_count,
)
else:
# Builder loaded but CLI native session not available.
# --resume will not be used this turn; upload after turn
# will seed the native session for the next turn.
#
# Still record transcript_msg_count so _build_query_message
# can use the transcript-aware gap path (inject only new
# messages since the transcript end) instead of compressing
# the full DB history. This avoids prompt-too-long on
# large sessions where the CLI session is temporarily
# unavailable (e.g. mixed-version rolling deployment).
transcript_msg_count = dl.message_count
logger.info(
"%s CLI session not restored — running without"
" --resume this turn (transcript_msg_count=%d for"
" gap-aware fallback)",
log_prefix,
transcript_msg_count,
)
else:
logger.warning("%s Transcript downloaded but invalid", log_prefix)
transcript_covers_prefix = False
elif config.claude_agent_use_resume and user_id and len(session.messages) > 1:
# No transcript in storage — reconstruct from DB messages as a
# last-resort fallback (e.g., first turn after a crash or transition).
# This path loses tool call IDs and structural fidelity but prevents
# a completely context-free response for established sessions.
prior = session.messages[:-1]
reconstructed = _session_messages_to_transcript(prior)
if reconstructed:
# Populate builder only; no --resume since there is no CLI
# native session to restore. The transcript builder state is
# still useful for the upload that seeds future native sessions.
transcript_content = reconstructed
transcript_builder.load_previous(reconstructed, log_prefix=log_prefix)
transcript_msg_count = len(prior)
transcript_covers_prefix = True
logger.info(
"%s Reconstructed transcript from %d session messages "
"(no CLI native session — running without --resume this turn)",
log_prefix,
len(prior),
)
else:
logger.warning(
"%s No transcript available and reconstruction produced empty"
" output (%d messages in session)",
log_prefix,
len(session.messages),
)
transcript_covers_prefix = False
# Restore CLI session — single GCS round-trip covers both --resume and builder state.
# message_count watermark lives in the companion .meta.json alongside the session file.
_restore = await _restore_cli_session_for_turn(
user_id, session_id, session, sdk_cwd, transcript_builder, log_prefix
)
transcript_content = _restore.transcript_content
transcript_covers_prefix = _restore.transcript_covers_prefix
use_resume = _restore.use_resume
resume_file = _restore.resume_file
transcript_msg_count = _restore.transcript_msg_count
restore_context_messages = _restore.context_messages
yield StreamStart(messageId=message_id, sessionId=session_id)
@@ -2680,14 +2929,14 @@ async def stream_chat_completion_sdk(
else:
# Set session_id whenever NOT resuming so the CLI writes the
# native session file to a predictable path for
# upload_cli_session() after the turn. This covers:
# upload_transcript() after the turn. This covers:
# • T1 fresh: no prior history, first SDK turn.
# • Mode-switch T1: has_history=True (prior baseline turns in
# DB) but no CLI session file was ever uploaded — the CLI has
# never been invoked with this session_id before.
# • T2+ without --resume (restore failed): no session file was
# restored to local storage (restore_cli_session returned
# False), so no conflict with an existing file.
# restored to local storage (download_transcript returned
# None), so no conflict with an existing file.
# When --resume is active the session_id is already implied by
# the resume file; passing it again would be rejected by the CLI.
sdk_options_kwargs["session_id"] = session_id
@@ -2780,6 +3029,7 @@ async def stream_chat_completion_sdk(
use_resume,
transcript_msg_count,
session_id,
prior_messages=restore_context_messages,
)
# If files are attached, prepare them: images become vision
# content blocks in the user message, other files go to sdk_cwd.
@@ -2909,7 +3159,7 @@ async def stream_chat_completion_sdk(
elif "session_id" in sdk_options_kwargs:
# Initial invocation used session_id (T1 or mode-switch
# T1): keep it so the CLI writes the session file to the
# predictable path for upload_cli_session(). Storage is
# predictable path for upload_transcript(). Storage is
# ephemeral per invocation, so no "Session ID already in
# use" conflict occurs — no prior file was restored.
sdk_options_kwargs_retry.pop("resume", None)
@@ -2932,6 +3182,10 @@ async def stream_chat_completion_sdk(
system_prompt, cross_user_cache=_cross_user_retry
)
state.options = ClaudeAgentOptions(**sdk_options_kwargs_retry) # type: ignore[arg-type] # dynamic kwargs
# Retry intentionally omits prior_messages (transcript+gap context) and
# falls back to full session.messages[:-1] from DB — the authoritative
# source. transcript+gap is an optimisation for the first attempt only;
# on retry the extra overhead of full-DB context is acceptable.
state.query_message, state.was_compacted = await _build_query_message(
current_message,
session,
@@ -3367,86 +3621,23 @@ async def stream_chat_completion_sdk(
_background_tasks.add(_ingest_task)
_ingest_task.add_done_callback(_background_tasks.discard)
# --- Upload transcript for next-turn --resume ---
# TranscriptBuilder is the single source of truth. It mirrors the
# CLI's active context: on compaction, replace_entries() syncs it
# with the compacted session file. No CLI file read needed here.
if skip_transcript_upload:
logger.warning(
"%s Skipping transcript upload — transcript was dropped "
"during prompt-too-long recovery",
log_prefix,
)
elif (
config.claude_agent_use_resume
and user_id
and session is not None
and state is not None
):
try:
transcript_upload_content = state.transcript_builder.to_jsonl()
entry_count = state.transcript_builder.entry_count
if not transcript_upload_content:
logger.warning(
"%s No transcript to upload (builder empty)", log_prefix
)
elif not validate_transcript(transcript_upload_content):
logger.warning(
"%s Transcript invalid, skipping upload (entries=%d)",
log_prefix,
entry_count,
)
elif not transcript_covers_prefix:
logger.warning(
"%s Skipping transcript upload — builder does not "
"cover full session prefix (entries=%d, session=%d)",
log_prefix,
entry_count,
len(session.messages),
)
else:
logger.info(
"%s Uploading transcript (entries=%d, bytes=%d)",
log_prefix,
entry_count,
len(transcript_upload_content),
)
await asyncio.shield(
upload_transcript(
user_id=user_id,
session_id=session_id,
content=transcript_upload_content,
message_count=len(session.messages),
log_prefix=log_prefix,
)
)
except Exception as upload_err:
logger.error(
"%s Transcript upload failed in finally: %s",
log_prefix,
upload_err,
exc_info=True,
)
# --- Upload CLI native session file for cross-pod --resume ---
# The CLI writes its native session JSONL after each turn completes.
# Uploading it here enables --resume on any pod (no pod affinity needed).
# Runs after upload_transcript so both are available for the next turn.
# asyncio.shield: same pattern as upload_transcript above — if the
# outer finally-block coroutine is cancelled while awaiting shield,
# the CancelledError propagates (BaseException, not caught by
# `except Exception`) letting the caller handle cancellation, while
# the shielded inner coroutine continues running to completion so the
# upload is not lost. This is intentional and matches the pattern
# used for upload_transcript immediately above.
# The companion .meta.json carries the message_count watermark and mode
# so the next turn can restore both --resume context and gap-fill state
# in a single GCS round-trip via download_transcript().
# asyncio.shield: if the outer finally-block coroutine is cancelled
# while awaiting shield, the CancelledError propagates (BaseException,
# not caught by `except Exception`) letting the caller handle
# cancellation, while the shielded inner coroutine continues running
# to completion so the upload is not lost.
#
# NOTE: upload is attempted regardless of state.use_resume — even when
# this turn ran without --resume (restore failed or first T2+ on a new
# pod), the T1 session file at the expected path may still be present
# and should be re-uploaded so the next turn can resume from it.
# upload_cli_session silently skips when the file is absent, so this is
# always safe.
# _read_cli_session_from_disk returns None when the file is absent, so
# this is always safe.
#
# Intentionally NOT gated on skip_transcript_upload: that flag is set
# when our custom JSONL transcript is dropped (transcript_lost=True on
@@ -3472,14 +3663,36 @@ async def stream_chat_completion_sdk(
skip_transcript_upload,
)
try:
await asyncio.shield(
upload_cli_session(
user_id=user_id,
session_id=session_id,
sdk_cwd=sdk_cwd,
log_prefix=log_prefix,
)
# Read the CLI's native session file from disk (written by the CLI
# after the turn), then upload the bytes to GCS.
_cli_content = _read_cli_session_from_disk(
sdk_cwd, session_id, log_prefix
)
if _cli_content:
# Watermark = number of DB messages this transcript covers.
# len(session.messages) is accurate: the CLI session file
# was just written after the turn completed, so it covers
# all messages through this turn. Any gap from a prior
# missed upload was already detected by detect_gap and
# injected as context, so the model has the full history.
#
# Previously this used _final_tmsg_count + 2, which
# under-counted for tool-use turns (delta = 2 + 2*N_tool_calls),
# causing persistent spurious gap-fills on every subsequent turn.
# That concern was addressed by the inflated-watermark fix
# (using the GCS watermark as the anchor for gap detection),
# which makes len(session.messages) safe to use here.
_jsonl_covered = len(session.messages)
await asyncio.shield(
upload_transcript(
user_id=user_id,
session_id=session_id,
content=_cli_content,
message_count=_jsonl_covered,
mode="sdk",
log_prefix=log_prefix,
)
)
except Exception as cli_upload_err:
logger.warning(
"%s CLI session upload failed in finally: %s",

View File

@@ -22,6 +22,7 @@ from .service import (
_iter_sdk_messages,
_normalize_model_name,
_reduce_context,
_restore_cli_session_for_turn,
_TokenUsage,
)
@@ -615,3 +616,340 @@ class TestSdkSessionIdSelection:
)
assert retry.get("resume") == self.SESSION_ID
assert "session_id" not in retry
# ---------------------------------------------------------------------------
# _restore_cli_session_for_turn — mode check
# ---------------------------------------------------------------------------
class TestRestoreCliSessionModeCheck:
"""SDK skips --resume when the transcript was written by the baseline mode."""
@pytest.mark.asyncio
async def test_baseline_mode_transcript_skips_gcs_content(self, tmp_path):
"""A transcript with mode='baseline' must not be used as the --resume source.
The mode check discards the GCS baseline content and falls back to DB
reconstruction from session.messages instead.
"""
from datetime import UTC, datetime
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.transcript import TranscriptDownload
from backend.copilot.transcript_builder import TranscriptBuilder
session = ChatSession(
session_id="test-session",
user_id="user-1",
messages=[
ChatMessage(role="user", content="hello-unique-marker"),
ChatMessage(role="assistant", content="world-unique-marker"),
ChatMessage(role="user", content="follow up"),
],
title="test",
usage=[],
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
builder = TranscriptBuilder()
# Baseline content with a sentinel that must NOT appear in the final transcript
baseline_restore = TranscriptDownload(
content=b'{"type":"user","uuid":"bad-uuid","message":{"role":"user","content":"BASELINE_SENTINEL"}}\n',
message_count=1,
mode="baseline",
)
import backend.copilot.sdk.service as _svc_mod
download_mock = AsyncMock(return_value=baseline_restore)
with (
patch(
"backend.copilot.sdk.service.download_transcript",
new=download_mock,
),
patch.object(_svc_mod.config, "claude_agent_use_resume", True),
):
result = await _restore_cli_session_for_turn(
user_id="user-1",
session_id="test-session",
session=session,
sdk_cwd=str(tmp_path),
transcript_builder=builder,
log_prefix="[Test]",
)
# download_transcript was called (attempted GCS restore)
download_mock.assert_awaited_once()
# use_resume must be False — baseline transcripts cannot be used with --resume
assert result.use_resume is False
# context_messages must be populated — new behaviour uses transcript content + gap
# instead of full DB reconstruction.
assert result.context_messages is not None
# The baseline transcript has 1 user message (BASELINE_SENTINEL).
# Watermark=1 but position 0 is 'user', not 'assistant', so detect_gap returns [].
# Result: 1 message from transcript, no gap.
assert len(result.context_messages) == 1
assert "BASELINE_SENTINEL" in (result.context_messages[0].content or "")
@pytest.mark.asyncio
async def test_sdk_mode_transcript_allows_resume(self, tmp_path):
"""A valid SDK-written transcript is accepted for --resume."""
import json as stdlib_json
from datetime import UTC, datetime
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
from backend.copilot.transcript_builder import TranscriptBuilder
lines = [
stdlib_json.dumps(
{
"type": "user",
"uuid": "uid-0",
"parentUuid": "",
"message": {"role": "user", "content": "hi"},
}
),
stdlib_json.dumps(
{
"type": "assistant",
"uuid": "uid-1",
"parentUuid": "uid-0",
"message": {
"role": "assistant",
"id": "msg_1",
"model": "test",
"type": "message",
"stop_reason": STOP_REASON_END_TURN,
"content": [{"type": "text", "text": "hello"}],
},
}
),
]
content = ("\n".join(lines) + "\n").encode("utf-8")
session = ChatSession(
session_id="test-session",
user_id="user-1",
messages=[
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="hello"),
ChatMessage(role="user", content="follow up"),
],
title="test",
usage=[],
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
builder = TranscriptBuilder()
sdk_restore = TranscriptDownload(
content=content,
message_count=2,
mode="sdk",
)
import backend.copilot.sdk.service as _svc_mod
with (
patch(
"backend.copilot.sdk.service.download_transcript",
new=AsyncMock(return_value=sdk_restore),
),
patch.object(_svc_mod.config, "claude_agent_use_resume", True),
):
result = await _restore_cli_session_for_turn(
user_id="user-1",
session_id="test-session",
session=session,
sdk_cwd=str(tmp_path),
transcript_builder=builder,
log_prefix="[Test]",
)
assert result.use_resume is True
@pytest.mark.asyncio
async def test_baseline_mode_context_messages_from_transcript_content(
self, tmp_path
):
"""mode='baseline' → context_messages populated from transcript content + gap.
When a baseline-mode transcript exists, extract_context_messages converts
the JSONL content to ChatMessage objects and returns them in context_messages.
use_resume must remain False.
"""
import json as stdlib_json
from datetime import UTC, datetime
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
from backend.copilot.transcript_builder import TranscriptBuilder
# Build a minimal valid JSONL transcript with 2 messages
lines = [
stdlib_json.dumps(
{
"type": "user",
"uuid": "uid-0",
"parentUuid": "",
"message": {"role": "user", "content": "TRANSCRIPT_USER"},
}
),
stdlib_json.dumps(
{
"type": "assistant",
"uuid": "uid-1",
"parentUuid": "uid-0",
"message": {
"role": "assistant",
"id": "msg_1",
"model": "test",
"type": "message",
"stop_reason": STOP_REASON_END_TURN,
"content": [{"type": "text", "text": "TRANSCRIPT_ASSISTANT"}],
},
}
),
]
content = ("\n".join(lines) + "\n").encode("utf-8")
session = ChatSession(
session_id="test-session",
user_id="user-1",
messages=[
ChatMessage(role="user", content="DB_USER"),
ChatMessage(role="assistant", content="DB_ASSISTANT"),
ChatMessage(role="user", content="current turn"),
],
title="test",
usage=[],
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
builder = TranscriptBuilder()
baseline_restore = TranscriptDownload(
content=content,
message_count=2,
mode="baseline",
)
import backend.copilot.sdk.service as _svc_mod
with (
patch(
"backend.copilot.sdk.service.download_transcript",
new=AsyncMock(return_value=baseline_restore),
),
patch.object(_svc_mod.config, "claude_agent_use_resume", True),
):
result = await _restore_cli_session_for_turn(
user_id="user-1",
session_id="test-session",
session=session,
sdk_cwd=str(tmp_path),
transcript_builder=builder,
log_prefix="[Test]",
)
assert result.use_resume is False
assert result.context_messages is not None
# Transcript content has 2 messages, no gap (watermark=2, session prior=2)
assert len(result.context_messages) == 2
assert result.context_messages[0].role == "user"
assert result.context_messages[1].role == "assistant"
assert "TRANSCRIPT_ASSISTANT" in (result.context_messages[1].content or "")
# transcript_content must be non-empty so the _seed_transcript guard in
# stream_chat_completion_sdk skips DB reconstruction (which would duplicate
# builder entries since load_previous appends).
assert result.transcript_content != ""
@pytest.mark.asyncio
async def test_baseline_mode_gap_present_context_includes_gap(self, tmp_path):
"""mode='baseline' + gap → context_messages includes transcript msgs and gap."""
import json as stdlib_json
from datetime import UTC, datetime
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
from backend.copilot.transcript_builder import TranscriptBuilder
# Transcript covers only 2 messages; session has 4 prior + current turn
lines = [
stdlib_json.dumps(
{
"type": "user",
"uuid": "uid-0",
"parentUuid": "",
"message": {"role": "user", "content": "TRANSCRIPT_USER_0"},
}
),
stdlib_json.dumps(
{
"type": "assistant",
"uuid": "uid-1",
"parentUuid": "uid-0",
"message": {
"role": "assistant",
"id": "msg_1",
"model": "test",
"type": "message",
"stop_reason": STOP_REASON_END_TURN,
"content": [{"type": "text", "text": "TRANSCRIPT_ASSISTANT_1"}],
},
}
),
]
content = ("\n".join(lines) + "\n").encode("utf-8")
session = ChatSession(
session_id="test-session",
user_id="user-1",
messages=[
ChatMessage(role="user", content="DB_USER_0"),
ChatMessage(role="assistant", content="DB_ASSISTANT_1"),
ChatMessage(role="user", content="GAP_USER_2"),
ChatMessage(role="assistant", content="GAP_ASSISTANT_3"),
ChatMessage(role="user", content="current turn"),
],
title="test",
usage=[],
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
builder = TranscriptBuilder()
baseline_restore = TranscriptDownload(
content=content,
message_count=2, # watermark=2; session has 4 prior → gap of 2
mode="baseline",
)
import backend.copilot.sdk.service as _svc_mod
with (
patch(
"backend.copilot.sdk.service.download_transcript",
new=AsyncMock(return_value=baseline_restore),
),
patch.object(_svc_mod.config, "claude_agent_use_resume", True),
):
result = await _restore_cli_session_for_turn(
user_id="user-1",
session_id="test-session",
session=session,
sdk_cwd=str(tmp_path),
transcript_builder=builder,
log_prefix="[Test]",
)
assert result.use_resume is False
assert result.context_messages is not None
# 2 from transcript + 2 gap messages = 4 total
assert len(result.context_messages) == 4
roles = [m.role for m in result.context_messages]
assert roles == ["user", "assistant", "user", "assistant"]
# Gap messages come from DB (ChatMessage objects)
gap_user = result.context_messages[2]
gap_asst = result.context_messages[3]
assert gap_user.content == "GAP_USER_2"
assert gap_asst.content == "GAP_ASSISTANT_3"

View File

@@ -0,0 +1,95 @@
"""Unit tests for the watermark-fix logic in stream_chat_completion_sdk.
The fix is at the upload step: when use_resume=True and transcript_msg_count>0
we set the JSONL coverage watermark to transcript_msg_count + 2 (the pair just
recorded) instead of len(session.messages). This prevents the "inflated
watermark" bug where a stale JSONL in GCS could hide missing context from
future gap-fill checks.
"""
from __future__ import annotations
def _compute_jsonl_covered(
use_resume: bool,
transcript_msg_count: int,
session_msg_count: int,
) -> int:
"""Mirror the watermark computation from ``stream_chat_completion_sdk``.
Extracted here so we can unit-test it independently without invoking the
full streaming stack.
"""
if use_resume and transcript_msg_count > 0:
return transcript_msg_count + 2
return session_msg_count
class TestWatermarkFix:
"""Watermark computation logic — mirrors the finally-block in SDK service."""
def test_inflated_watermark_triggers_gap_fill(self):
"""Stale JSONL (T12) with high watermark (46) → after fix, watermark=14.
Before fix: watermark=46 → next turn's gap check (transcript_msg_count < db-1)
never fires because 46 >= 47-1=46, so context loss is silent.
After fix: watermark = 12 + 2 = 14 → gap check fires (14 < 46) and
the model receives the missing turns.
"""
# Simulate: use_resume=True, transcript covered T12 (12 msgs), DB now has 47
use_resume = True
transcript_msg_count = 12
session_msg_count = 47 # DB count (what old code used to set watermark)
watermark = _compute_jsonl_covered(
use_resume, transcript_msg_count, session_msg_count
)
assert watermark == 14 # 12 + 2, NOT 47
# Verify: the gap check would fire on next turn
# next-turn check: transcript_msg_count < msg_count - 1 → 14 < 47-1=46 → True
assert watermark < session_msg_count - 1
def test_no_false_positive_when_transcript_current(self):
"""Transcript current (watermark=46, DB=47) → gap stays 0.
When the JSONL actually covers T46 (the most recent assistant turn),
uploading watermark=46+2=48 means next turn's gap check sees
48 >= 48-1=47 → no gap. Correct.
"""
use_resume = True
transcript_msg_count = 46
session_msg_count = 47
watermark = _compute_jsonl_covered(
use_resume, transcript_msg_count, session_msg_count
)
assert watermark == 48 # 46 + 2
# Next turn: session has 48 msgs, watermark=48 → 48 >= 48-1=47 → no gap
next_turn_session = 48
assert watermark >= next_turn_session - 1
def test_fresh_session_falls_back_to_db_count(self):
"""use_resume=False → watermark = len(session.messages) (original behaviour)."""
use_resume = False
transcript_msg_count = 0
session_msg_count = 3
watermark = _compute_jsonl_covered(
use_resume, transcript_msg_count, session_msg_count
)
assert watermark == session_msg_count
def test_old_format_meta_zero_count_falls_back_to_db(self):
"""transcript_msg_count=0 (old-format meta with no count field) → DB fallback."""
use_resume = True
transcript_msg_count = 0 # old-format meta or not-yet-set
session_msg_count = 10
watermark = _compute_jsonl_covered(
use_resume, transcript_msg_count, session_msg_count
)
assert watermark == session_msg_count

View File

@@ -12,18 +12,20 @@ from backend.copilot.transcript import (
ENTRY_TYPE_MESSAGE,
STOP_REASON_END_TURN,
STRIPPABLE_TYPES,
TRANSCRIPT_STORAGE_PREFIX,
TranscriptDownload,
TranscriptMode,
cleanup_stale_project_dirs,
cli_session_path,
compact_transcript,
delete_transcript,
detect_gap,
download_transcript,
extract_context_messages,
projects_base,
read_compacted_entries,
restore_cli_session,
strip_for_upload,
strip_progress_entries,
strip_stale_thinking_blocks,
upload_cli_session,
upload_transcript,
validate_transcript,
write_transcript_to_tempfile,
@@ -34,18 +36,20 @@ __all__ = [
"ENTRY_TYPE_MESSAGE",
"STOP_REASON_END_TURN",
"STRIPPABLE_TYPES",
"TRANSCRIPT_STORAGE_PREFIX",
"TranscriptDownload",
"TranscriptMode",
"cleanup_stale_project_dirs",
"cli_session_path",
"compact_transcript",
"delete_transcript",
"detect_gap",
"download_transcript",
"extract_context_messages",
"projects_base",
"read_compacted_entries",
"restore_cli_session",
"strip_for_upload",
"strip_progress_entries",
"strip_stale_thinking_blocks",
"upload_cli_session",
"upload_transcript",
"validate_transcript",
"write_transcript_to_tempfile",

View File

@@ -297,8 +297,8 @@ class TestStripProgressEntries:
class TestDeleteTranscript:
@pytest.mark.asyncio
async def test_deletes_both_jsonl_and_meta(self):
"""delete_transcript removes both the .jsonl and .meta.json files."""
async def test_deletes_cli_session_and_meta(self):
"""delete_transcript removes the CLI session .jsonl and .meta.json."""
mock_storage = AsyncMock()
mock_storage.delete = AsyncMock()
@@ -309,7 +309,7 @@ class TestDeleteTranscript:
):
await delete_transcript("user-123", "session-456")
assert mock_storage.delete.call_count == 3
assert mock_storage.delete.call_count == 2
paths = [call.args[0] for call in mock_storage.delete.call_args_list]
assert any(p.endswith(".jsonl") for p in paths)
assert any(p.endswith(".meta.json") for p in paths)
@@ -319,7 +319,7 @@ class TestDeleteTranscript:
"""If .jsonl delete fails, .meta.json delete is still attempted."""
mock_storage = AsyncMock()
mock_storage.delete = AsyncMock(
side_effect=[Exception("jsonl delete failed"), None, None]
side_effect=[Exception("jsonl delete failed"), None]
)
with patch(
@@ -330,14 +330,14 @@ class TestDeleteTranscript:
# Should not raise
await delete_transcript("user-123", "session-456")
assert mock_storage.delete.call_count == 3
assert mock_storage.delete.call_count == 2
@pytest.mark.asyncio
async def test_handles_meta_delete_failure(self):
"""If .meta.json delete fails, no exception propagates."""
mock_storage = AsyncMock()
mock_storage.delete = AsyncMock(
side_effect=[None, Exception("meta delete failed"), None]
side_effect=[None, Exception("meta delete failed")]
)
with patch(
@@ -1015,7 +1015,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: str(projects_dir),
)
@@ -1044,7 +1044,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: str(projects_dir),
)
@@ -1070,7 +1070,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: str(projects_dir),
)
@@ -1096,7 +1096,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: str(projects_dir),
)
@@ -1118,7 +1118,7 @@ class TestCleanupStaleProjectDirs:
nonexistent = str(tmp_path / "does-not-exist" / "projects")
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: nonexistent,
)
@@ -1137,7 +1137,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: str(projects_dir),
)
@@ -1165,7 +1165,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: str(projects_dir),
)
@@ -1189,7 +1189,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.transcript._projects_base",
"backend.copilot.transcript.projects_base",
lambda: str(projects_dir),
)
@@ -1368,3 +1368,172 @@ class TestStripStaleThinkingBlocks:
# Both entries of last turn (msg_last) preserved
assert lines[1]["message"]["content"][0]["type"] == "thinking"
assert lines[2]["message"]["content"][0]["type"] == "text"
class TestProcessCliRestore:
"""``_process_cli_restore`` validates, strips, and writes CLI session to disk."""
def test_writes_stripped_bytes_not_raw(self, tmp_path):
"""Stripped bytes (not raw bytes) must be written to disk for --resume."""
import os
import re
from pathlib import Path
from unittest.mock import patch
from backend.copilot.sdk.service import _process_cli_restore
from backend.copilot.transcript import TranscriptDownload
session_id = "12345678-0000-0000-0000-abcdef000001"
sdk_cwd = str(tmp_path)
projects_base_dir = str(tmp_path)
# Build raw content with a strippable progress entry + a valid user/assistant pair
raw_content = (
'{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n'
'{"type":"user","uuid":"u1","parentUuid":null,"message":{"role":"user","content":"hi"}}\n'
'{"type":"assistant","uuid":"a1","parentUuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello"}]}}\n'
)
raw_bytes = raw_content.encode("utf-8")
restore = TranscriptDownload(content=raw_bytes, message_count=2, mode="sdk")
with (
patch(
"backend.copilot.sdk.service.projects_base",
return_value=projects_base_dir,
),
patch(
"backend.copilot.transcript.projects_base",
return_value=projects_base_dir,
),
):
stripped_str, ok = _process_cli_restore(
restore, sdk_cwd, session_id, "[Test]"
)
assert ok, "Expected successful restore"
# Find the written session file
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
session_file = Path(projects_base_dir) / encoded_cwd / f"{session_id}.jsonl"
assert session_file.exists(), "Session file should have been written"
written_bytes = session_file.read_bytes()
# The written bytes must be the stripped version (no progress entry)
assert (
b"progress" not in written_bytes
), "Raw bytes with progress entry should not have been written"
assert (
b"hello" in written_bytes
), "Stripped content should still contain assistant turn"
# Written bytes must equal the stripped string re-encoded
assert written_bytes == stripped_str.encode(
"utf-8"
), "Written bytes must equal stripped content"
def test_invalid_content_returns_false(self):
"""Content that fails validation after strip returns (empty, False)."""
from backend.copilot.sdk.service import _process_cli_restore
from backend.copilot.transcript import TranscriptDownload
# A single progress-only entry — stripped result will be empty/invalid
raw_content = '{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n'
restore = TranscriptDownload(
content=raw_content.encode("utf-8"), message_count=1, mode="sdk"
)
stripped_str, ok = _process_cli_restore(
restore,
"/tmp/nonexistent-sdk-cwd",
"12345678-0000-0000-0000-000000000099",
"[Test]",
)
assert not ok
assert stripped_str == ""
class TestReadCliSessionFromDisk:
"""``_read_cli_session_from_disk`` reads, strips, and optionally writes back the session."""
def _build_session_file(self, tmp_path, session_id: str):
"""Build the session file path inside tmp_path using the same encoding as cli_session_path."""
import os
import re
from pathlib import Path
sdk_cwd = str(tmp_path)
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
session_dir = Path(str(tmp_path)) / encoded_cwd
session_dir.mkdir(parents=True, exist_ok=True)
return sdk_cwd, session_dir / f"{session_id}.jsonl"
def test_returns_raw_bytes_for_invalid_utf8(self, tmp_path):
"""Non-UTF-8 bytes trigger UnicodeDecodeError — returns raw bytes (upload-raw fallback)."""
from unittest.mock import patch
from backend.copilot.sdk.service import _read_cli_session_from_disk
session_id = "12345678-0000-0000-0000-aabbccdd0001"
projects_base_dir = str(tmp_path)
sdk_cwd, session_file = self._build_session_file(tmp_path, session_id)
# Write raw invalid UTF-8 bytes
session_file.write_bytes(b"\xff\xfe invalid utf-8\n")
with (
patch(
"backend.copilot.sdk.service.projects_base",
return_value=projects_base_dir,
),
patch(
"backend.copilot.transcript.projects_base",
return_value=projects_base_dir,
),
):
result = _read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
# UnicodeDecodeError path returns the raw bytes (upload-raw fallback)
assert result == b"\xff\xfe invalid utf-8\n"
def test_write_back_oserror_still_returns_stripped_bytes(self, tmp_path):
"""OSError on write-back returns stripped bytes for GCS upload (not raw)."""
from unittest.mock import patch
from backend.copilot.sdk.service import _read_cli_session_from_disk
session_id = "12345678-0000-0000-0000-aabbccdd0002"
projects_base_dir = str(tmp_path)
sdk_cwd, session_file = self._build_session_file(tmp_path, session_id)
# Content with a strippable progress entry so stripped_bytes < raw_bytes
raw_content = (
'{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n'
'{"type":"user","uuid":"u1","parentUuid":null,"message":{"role":"user","content":"hi"}}\n'
'{"type":"assistant","uuid":"a1","parentUuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello"}]}}\n'
)
session_file.write_bytes(raw_content.encode("utf-8"))
# Make the file read-only so write_bytes raises OSError on the write-back
session_file.chmod(0o444)
try:
with (
patch(
"backend.copilot.sdk.service.projects_base",
return_value=projects_base_dir,
),
patch(
"backend.copilot.transcript.projects_base",
return_value=projects_base_dir,
),
):
result = _read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
finally:
session_file.chmod(0o644)
# Must return stripped bytes (not raw, not None) so GCS gets the clean version
assert result is not None
assert (
b"progress" not in result
), "Stripped bytes must not contain progress entry"
assert b"hello" in result, "Stripped bytes should contain assistant turn"

View File

@@ -61,18 +61,23 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
# (CLI version, platform). When that happens, multi-turn still works
# via conversation compression (non-resume path), but we can't test
# the --resume round-trip.
transcript = None
cli_session = None
for _ in range(10):
await asyncio.sleep(0.5)
transcript = await download_transcript(test_user_id, session.session_id)
if transcript:
cli_session = await download_transcript(test_user_id, session.session_id)
# Wait until both the session bytes AND the message_count watermark are
# present — a session with message_count=0 means the .meta.json hasn't
# been uploaded yet, so --resume on the next turn would skip gap-fill.
if cli_session and cli_session.message_count > 0:
break
if not transcript:
if not cli_session:
return pytest.skip(
"CLI did not produce a usable transcript — "
"cannot test --resume round-trip in this environment"
)
logger.info(f"Turn 1 transcript uploaded: {len(transcript.content)} bytes")
logger.info(
f"Turn 1 CLI session uploaded: {len(cli_session.content)} bytes, msg_count={cli_session.message_count}"
)
# Reload session for turn 2
session = await get_chat_session(session.session_id, test_user_id)

View File

@@ -1,10 +1,10 @@
"""JSONL transcript management for stateless multi-turn resume.
The Claude Code CLI persists conversations as JSONL files (one JSON object per
line). When the SDK's ``Stop`` hook fires we read this file, strip bloat
(progress entries, metadata), and upload the result to bucket storage. On the
next turn we download the transcript, write it to a temp file, and pass
``--resume`` so the CLI can reconstruct the full conversation.
line). When the SDK's ``Stop`` hook fires the caller reads this file, strips
bloat (progress entries, metadata), and uploads the result to bucket storage.
On the next turn the caller downloads the bytes and writes them to disk before
passing ``--resume`` so the CLI can reconstruct the full conversation.
Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
filesystem for self-hosted) — no DB column needed.
@@ -20,6 +20,7 @@ import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from uuid import uuid4
from backend.util import json
@@ -27,6 +28,9 @@ from backend.util.clients import get_openai_client
from backend.util.prompt import CompressResult, compress_context
from backend.util.workspace_storage import GCSWorkspaceStorage, get_workspace_storage
if TYPE_CHECKING:
from .model import ChatMessage
logger = logging.getLogger(__name__)
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
@@ -44,17 +48,17 @@ STRIPPABLE_TYPES = frozenset(
)
TranscriptMode = Literal["sdk", "baseline"]
@dataclass
class TranscriptDownload:
"""Result of downloading a transcript with its metadata."""
content: str
message_count: int = 0 # session.messages length when uploaded
uploaded_at: float = 0.0 # epoch timestamp of upload
content: bytes | str
message_count: int = 0
# "sdk" = Claude CLI native, "baseline" = TranscriptBuilder
mode: TranscriptMode = "sdk"
# Workspace storage constants — deterministic path from session_id.
TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"
# Storage prefix for the CLI's native session JSONL files (for cross-pod --resume).
_CLI_SESSION_STORAGE_PREFIX = "cli-sessions"
@@ -363,7 +367,7 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
def _projects_base() -> str:
def projects_base() -> str:
"""Return the resolved path to the CLI's projects directory."""
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
return os.path.realpath(os.path.join(config_dir, "projects"))
@@ -390,8 +394,8 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
Returns the number of directories removed.
"""
projects_base = _projects_base()
if not os.path.isdir(projects_base):
_pbase = projects_base()
if not os.path.isdir(_pbase):
return 0
now = time.time()
@@ -399,7 +403,7 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
# Scoped mode: only clean up the one directory for the current session.
if encoded_cwd:
target = Path(projects_base) / encoded_cwd
target = Path(_pbase) / encoded_cwd
if not target.is_dir():
return 0
# Guard: only sweep copilot-generated dirs.
@@ -437,7 +441,7 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
# Only safe for single-tenant deployments; callers should prefer the
# scoped variant by passing encoded_cwd.
try:
entries = Path(projects_base).iterdir()
entries = Path(_pbase).iterdir()
except OSError as e:
logger.warning("[Transcript] Failed to list projects dir: %s", e)
return 0
@@ -490,9 +494,9 @@ def read_compacted_entries(transcript_path: str) -> list[dict] | None:
if not transcript_path:
return None
projects_base = _projects_base()
_pbase = projects_base()
real_path = os.path.realpath(transcript_path)
if not real_path.startswith(projects_base + os.sep):
if not real_path.startswith(_pbase + os.sep):
logger.warning(
"[Transcript] transcript_path outside projects base: %s", transcript_path
)
@@ -611,28 +615,6 @@ def validate_transcript(content: str | None) -> bool:
# ---------------------------------------------------------------------------
def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
"""Return (workspace_id, file_id, filename) for a session's transcript.
Path structure: ``chat-transcripts/{user_id}/{session_id}.jsonl``
IDs are sanitized to hex+hyphen to prevent path traversal.
"""
return (
TRANSCRIPT_STORAGE_PREFIX,
_sanitize_id(user_id),
f"{_sanitize_id(session_id)}.jsonl",
)
def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
"""Return (workspace_id, file_id, filename) for a session's transcript metadata."""
return (
TRANSCRIPT_STORAGE_PREFIX,
_sanitize_id(user_id),
f"{_sanitize_id(session_id)}.meta.json",
)
def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
"""Build a full storage path from (workspace_id, file_id, filename) parts."""
wid, fid, fname = parts
@@ -642,24 +624,12 @@ def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
return f"local://{wid}/{fid}/{fname}"
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
"""Build the full storage path string that ``retrieve()`` expects."""
return _build_path_from_parts(_storage_path_parts(user_id, session_id), backend)
def _build_meta_storage_path(user_id: str, session_id: str, backend: object) -> str:
"""Build the full storage path for the companion .meta.json file."""
return _build_path_from_parts(
_meta_storage_path_parts(user_id, session_id), backend
)
# ---------------------------------------------------------------------------
# CLI native session file — cross-pod --resume support
# ---------------------------------------------------------------------------
def _cli_session_path(sdk_cwd: str, session_id: str) -> str:
def cli_session_path(sdk_cwd: str, session_id: str) -> str:
"""Expected path of the CLI's native session JSONL file.
The CLI resolves the working directory via ``os.path.realpath``, then
@@ -675,7 +645,7 @@ def _cli_session_path(sdk_cwd: str, session_id: str) -> str:
"""
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
safe_id = _sanitize_id(session_id)
return os.path.join(_projects_base(), encoded_cwd, f"{safe_id}.jsonl")
return os.path.join(projects_base(), encoded_cwd, f"{safe_id}.jsonl")
def _cli_session_storage_path_parts(
@@ -689,235 +659,82 @@ def _cli_session_storage_path_parts(
)
async def upload_cli_session(
user_id: str,
session_id: str,
sdk_cwd: str,
log_prefix: str = "[Transcript]",
) -> None:
"""Upload the CLI's native session JSONL file to remote storage.
Called after each turn so the next turn can restore the file on any pod
(eliminating the pod-affinity requirement for --resume).
The CLI only writes the session file after the turn completes, so this
must run in the finally block, AFTER the SDK stream has finished.
"""
session_file = _cli_session_path(sdk_cwd, session_id)
real_path = os.path.realpath(session_file)
projects_base = _projects_base()
if not real_path.startswith(projects_base + os.sep):
logger.warning(
"%s CLI session file outside projects base, skipping upload: %s",
log_prefix,
os.path.basename(real_path),
)
return
try:
raw_bytes = Path(real_path).read_bytes()
except FileNotFoundError:
logger.debug(
"%s CLI session file not found, skipping upload: %s",
log_prefix,
session_file,
)
return
except OSError as e:
logger.warning("%s Failed to read CLI session file: %s", log_prefix, e)
return
# Strip stale thinking blocks and metadata entries (progress, file-history-snapshot,
# queue-operation) from the CLI session before writing it back locally and uploading
# to GCS. Thinking blocks from non-last assistant turns are not needed for --resume
# but can be massive (tens of thousands of tokens each), causing the CLI to auto-compact
# its session when the context window fills up. Stripping keeps the session well below
# the ~200K-token compaction threshold and prevents silent context loss.
try:
raw_text = raw_bytes.decode("utf-8")
stripped_text = strip_for_upload(raw_text)
stripped_bytes = stripped_text.encode("utf-8")
if len(stripped_bytes) < len(raw_bytes):
# Write the stripped version back locally so same-pod turns also benefit.
Path(real_path).write_bytes(stripped_bytes)
logger.info(
"%s Stripped CLI session file: %dB → %dB",
log_prefix,
len(raw_bytes),
len(stripped_bytes),
)
content = stripped_bytes
except Exception as e:
logger.warning(
"%s Failed to strip CLI session file, uploading raw: %s", log_prefix, e
)
content = raw_bytes
storage = await get_workspace_storage()
wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id)
try:
await storage.store(
workspace_id=wid, file_id=fid, filename=fname, content=content
)
logger.info(
"%s Uploaded CLI session file (%dB) for cross-pod --resume",
log_prefix,
len(content),
)
except Exception as e:
logger.warning("%s Failed to upload CLI session file: %s", log_prefix, e)
async def restore_cli_session(
user_id: str,
session_id: str,
sdk_cwd: str,
log_prefix: str = "[Transcript]",
) -> bool:
"""Download and restore the CLI's native session file for --resume.
Returns True if the file was successfully restored and --resume can be
used with the session UUID. Returns False if not available (first turn
or upload failed), in which case the caller should not set --resume.
"""
session_file = _cli_session_path(sdk_cwd, session_id)
real_path = os.path.realpath(session_file)
projects_base = _projects_base()
if not real_path.startswith(projects_base + os.sep):
logger.warning(
"%s CLI session restore path outside projects base: %s",
log_prefix,
os.path.basename(session_file),
)
return False
# If the session file already exists locally (same-pod reuse), use it directly.
# Downloading from storage could overwrite a newer local version when a previous
# turn's upload failed: stored content is stale while the local file already
# contains extended history from that turn.
if Path(real_path).exists():
logger.debug(
"%s CLI session file already exists locally — using it for --resume",
log_prefix,
)
return True
storage = await get_workspace_storage()
path = _build_path_from_parts(
_cli_session_storage_path_parts(user_id, session_id), storage
def _cli_session_meta_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
"""Return (workspace_id, file_id, filename) for the CLI session meta file."""
return (
_CLI_SESSION_STORAGE_PREFIX,
_sanitize_id(user_id),
f"{_sanitize_id(session_id)}.meta.json",
)
try:
content = await storage.retrieve(path)
except FileNotFoundError:
logger.debug("%s No CLI session in storage (first turn or missing)", log_prefix)
return False
except Exception as e:
logger.warning("%s Failed to download CLI session: %s", log_prefix, e)
return False
try:
os.makedirs(os.path.dirname(real_path), exist_ok=True)
Path(real_path).write_bytes(content)
logger.info(
"%s Restored CLI session file (%dB) for --resume",
log_prefix,
len(content),
)
return True
except OSError as e:
logger.warning("%s Failed to write CLI session file: %s", log_prefix, e)
return False
async def upload_transcript(
user_id: str,
session_id: str,
content: str,
content: bytes,
message_count: int = 0,
mode: TranscriptMode = "sdk",
log_prefix: str = "[Transcript]",
skip_strip: bool = False,
) -> None:
"""Strip progress entries and stale thinking blocks, then upload transcript.
"""Upload CLI session content to GCS with companion meta.json.
The transcript represents the FULL active context (atomic).
Each upload REPLACES the previous transcript entirely.
Pure GCS operation — no disk I/O. The caller is responsible for reading
the session file from disk before calling this function.
The executor holds a cluster lock per session, so concurrent uploads for
the same session cannot happen.
Also uploads a companion .meta.json with the message_count watermark so
download_transcript can return it without a separate fetch.
Args:
content: Complete JSONL transcript (from TranscriptBuilder).
message_count: ``len(session.messages)`` at upload time.
skip_strip: When ``True``, skip the strip + re-validate pass.
Safe for builder-generated content (baseline path) which
never emits progress entries or stale thinking blocks.
Called after each turn so the next turn can restore the file on any pod
(eliminating the pod-affinity requirement for --resume).
"""
if skip_strip:
# Caller guarantees the content is already clean and valid.
stripped = content
else:
# Strip metadata entries and stale thinking blocks in a single parse.
# SDK-built transcripts may have progress entries; strip for safety.
stripped = strip_for_upload(content)
if not skip_strip and not validate_transcript(stripped):
# Log entry types for debugging — helps identify why validation failed
entry_types = [
json.loads(line, fallback={"type": "INVALID_JSON"}).get("type", "?")
for line in stripped.strip().split("\n")
]
logger.warning(
"%s Skipping upload — stripped content not valid "
"(types=%s, stripped_len=%d, raw_len=%d)",
log_prefix,
entry_types,
len(stripped),
len(content),
)
logger.debug("%s Raw content preview: %s", log_prefix, content[:500])
logger.debug("%s Stripped content: %s", log_prefix, stripped[:500])
return
storage = await get_workspace_storage()
wid, fid, fname = _storage_path_parts(user_id, session_id)
encoded = stripped.encode("utf-8")
meta = {"message_count": message_count, "uploaded_at": time.time()}
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id)
mwid, mfid, mfname = _cli_session_meta_path_parts(user_id, session_id)
meta = {"message_count": message_count, "mode": mode, "uploaded_at": time.time()}
meta_encoded = json.dumps(meta).encode("utf-8")
# Transcript + metadata are independent objects at different keys, so
# write them concurrently. ``return_exceptions`` keeps a metadata
# failure from sinking the transcript write.
transcript_result, metadata_result = await asyncio.gather(
storage.store(
workspace_id=wid,
file_id=fid,
filename=fname,
content=encoded,
),
storage.store(
workspace_id=mwid,
file_id=mfid,
filename=mfname,
content=meta_encoded,
),
return_exceptions=True,
)
if isinstance(transcript_result, BaseException):
raise transcript_result
if isinstance(metadata_result, BaseException):
# Metadata is best-effort — the gap-fill logic in
# _build_query_message tolerates a missing metadata file.
logger.warning("%s Failed to write metadata: %s", log_prefix, metadata_result)
# Write JSONL first, meta second — sequential so a crash between the two
# leaves an orphaned JSONL (no meta) rather than an orphaned meta (wrong
# watermark / mode paired with stale or absent content).
# On any failure we roll back the other file so the pair is always absent
# together; download_transcript returns None when either file is missing.
try:
await storage.store(
workspace_id=wid, file_id=fid, filename=fname, content=content
)
except Exception as session_err:
logger.warning(
"%s Failed to upload CLI session file: %s", log_prefix, session_err
)
return
try:
await storage.store(
workspace_id=mwid, file_id=mfid, filename=mfname, content=meta_encoded
)
except Exception as meta_err:
logger.warning("%s Failed to upload CLI session meta: %s", log_prefix, meta_err)
# Roll back the JSONL so neither file exists — avoids orphaned JSONL being
# used with wrong mode/watermark defaults on the next restore.
try:
session_path = _build_path_from_parts(
_cli_session_storage_path_parts(user_id, session_id), storage
)
await storage.delete(session_path)
except Exception as rollback_err:
logger.debug(
"%s Session rollback failed (harmless — download will return None): %s",
log_prefix,
rollback_err,
)
return
logger.info(
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
"%s Uploaded CLI session (%dB, msg_count=%d, mode=%s)",
log_prefix,
len(encoded),
len(content),
message_count,
mode,
)
@@ -926,83 +743,173 @@ async def download_transcript(
session_id: str,
log_prefix: str = "[Transcript]",
) -> TranscriptDownload | None:
"""Download transcript and metadata from bucket storage.
"""Download CLI session from GCS. Returns content + message_count + mode, or None if not found.
Returns a ``TranscriptDownload`` with the JSONL content and the
``message_count`` watermark from the upload, or ``None`` if not found.
Pure GCS operation — no disk I/O. The caller is responsible for writing
content to disk if --resume is needed.
The content and metadata fetches run concurrently since they are
independent objects in the bucket.
Returns a TranscriptDownload with the raw content, message_count watermark,
and mode on success, or None if not available (first turn or upload failed).
"""
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
meta_path = _build_meta_storage_path(user_id, session_id, storage)
path = _build_path_from_parts(
_cli_session_storage_path_parts(user_id, session_id), storage
)
meta_path = _build_path_from_parts(
_cli_session_meta_path_parts(user_id, session_id), storage
)
content_task = asyncio.create_task(storage.retrieve(path))
meta_task = asyncio.create_task(storage.retrieve(meta_path))
content_result, meta_result = await asyncio.gather(
content_task, meta_task, return_exceptions=True
storage.retrieve(path),
storage.retrieve(meta_path),
return_exceptions=True,
)
if isinstance(content_result, FileNotFoundError):
logger.debug("%s No transcript in storage", log_prefix)
logger.debug("%s No CLI session in storage (first turn or missing)", log_prefix)
return None
if isinstance(content_result, BaseException):
logger.warning(
"%s Failed to download transcript: %s", log_prefix, content_result
"%s Failed to download CLI session: %s", log_prefix, content_result
)
return None
content = content_result.decode("utf-8")
content: bytes = content_result
# Metadata is best-effort — old transcripts won't have it.
# Parse message_count and mode from companion meta best-effort, defaults.
message_count = 0
uploaded_at = 0.0
mode: TranscriptMode = "sdk"
if isinstance(meta_result, FileNotFoundError):
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
pass # No meta — old upload; default to "sdk"
elif isinstance(meta_result, BaseException):
logger.debug(
"%s Failed to load transcript metadata: %s", log_prefix, meta_result
)
logger.debug("%s Failed to load CLI session meta: %s", log_prefix, meta_result)
else:
meta = json.loads(meta_result.decode("utf-8"), fallback={})
message_count = meta.get("message_count", 0)
uploaded_at = meta.get("uploaded_at", 0.0)
try:
meta_str = meta_result.decode("utf-8")
except UnicodeDecodeError:
logger.debug("%s CLI session meta is not valid UTF-8, ignoring", log_prefix)
meta_str = None
if meta_str is not None:
meta = json.loads(meta_str, fallback={})
if isinstance(meta, dict):
raw_count = meta.get("message_count", 0)
message_count = (
raw_count if isinstance(raw_count, int) and raw_count >= 0 else 0
)
raw_mode = meta.get("mode", "sdk")
mode = raw_mode if raw_mode in ("sdk", "baseline") else "sdk"
logger.info(
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
)
return TranscriptDownload(
content=content,
message_count=message_count,
uploaded_at=uploaded_at,
"%s Downloaded CLI session (%dB, msg_count=%d, mode=%s)",
log_prefix,
len(content),
message_count,
mode,
)
return TranscriptDownload(content=content, message_count=message_count, mode=mode)
def detect_gap(
download: TranscriptDownload,
session_messages: list[ChatMessage],
) -> list[ChatMessage]:
"""Return chat-db messages after the transcript watermark (excluding current user turn).
Returns [] if transcript is current, watermark is zero, or the watermark
position doesn't end on an assistant turn (misaligned watermark).
"""
if download.message_count == 0:
return []
wm = download.message_count
total = len(session_messages)
if wm >= total - 1:
return []
# Sanity: position wm-1 should be an assistant turn; misaligned watermark
# means the DB messages shifted (e.g. deletion) — skip gap to avoid wrong context.
# In normal operation ``message_count`` is always written after a complete
# user→assistant exchange (never mid-turn), so the last covered position is
# always assistant. This guard fires only on data corruption or message deletion.
if session_messages[wm - 1].role != "assistant":
return []
return list(session_messages[wm : total - 1])
def extract_context_messages(
download: TranscriptDownload | None,
session_messages: "list[ChatMessage]",
) -> "list[ChatMessage]":
"""Return context messages for the current turn: transcript content + gap.
This is the shared context primitive used by both the SDK path
(``use_resume=False`` → ``<conversation_history>`` injection) and the
baseline path (OpenAI messages array).
How it works:
- When a transcript exists, ``TranscriptBuilder.load_previous`` preserves
``isCompactSummary=True`` compaction entries, so the returned messages
mirror the compacted context the CLI would see via ``--resume``.
- The gap (DB messages after the transcript watermark) is always small in
normal operation; it only grows during mode switches or when an upload
was missed.
- Falls back to full DB messages when no transcript exists (first turn,
upload failure, or GCS unavailable).
- Returns *prior* messages only (excluding the current user turn at
``session_messages[-1]``). Callers that need the current turn append
``session_messages[-1]`` themselves.
- **Tool calls from transcript entries are flattened to text**: assistant
messages derived from the JSONL use ``_flatten_assistant_content``, which
serialises ``tool_use`` blocks as human-readable text rather than
structured ``tool_calls``. Gap messages (from DB) preserve their
original ``tool_calls`` field. This is the same trade-off as the old
``_compress_session_messages(session.messages)`` approach — no regression.
Args:
download: The ``TranscriptDownload`` from GCS, or ``None`` when no
transcript is available. ``content`` may be either ``bytes`` or
``str`` (the baseline path decodes + strips before returning).
session_messages: All messages in the session, with the current user
turn as the last element.
Returns:
A list of ``ChatMessage`` objects covering the prior conversation
context, suitable for injection as conversation history.
"""
from .model import ChatMessage as _ChatMessage # runtime import
prior = session_messages[:-1]
if download is None:
return prior
raw_content = download.content
if not raw_content:
return prior
# Handle both bytes (raw GCS download) and str (pre-decoded baseline path).
if isinstance(raw_content, bytes):
try:
content_str: str = raw_content.decode("utf-8")
except UnicodeDecodeError:
return prior
else:
content_str = raw_content
raw = _transcript_to_messages(content_str)
if not raw:
return prior
transcript_msgs = [
_ChatMessage(role=m["role"], content=m.get("content") or "") for m in raw
]
gap = detect_gap(download, session_messages)
return transcript_msgs + gap
async def delete_transcript(user_id: str, session_id: str) -> None:
"""Delete transcript and its metadata from bucket storage.
Removes both the ``.jsonl`` transcript and the companion ``.meta.json``
so stale ``message_count`` watermarks cannot corrupt gap-fill logic.
"""
"""Delete CLI session JSONL and its companion .meta.json from bucket storage."""
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
try:
await storage.delete(path)
logger.info("[Transcript] Deleted transcript for session %s", session_id)
except Exception as e:
logger.warning("[Transcript] Failed to delete transcript: %s", e)
# Also delete the companion .meta.json to avoid orphaned metadata.
try:
meta_path = _build_meta_storage_path(user_id, session_id, storage)
await storage.delete(meta_path)
logger.info("[Transcript] Deleted metadata for session %s", session_id)
except Exception as e:
logger.warning("[Transcript] Failed to delete metadata: %s", e)
# Also delete the CLI native session file to prevent storage growth.
try:
cli_path = _build_path_from_parts(
_cli_session_storage_path_parts(user_id, session_id), storage
@@ -1012,6 +919,15 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
except Exception as e:
logger.warning("[Transcript] Failed to delete CLI session: %s", e)
try:
cli_meta_path = _build_path_from_parts(
_cli_session_meta_path_parts(user_id, session_id), storage
)
await storage.delete(cli_meta_path)
logger.info("[Transcript] Deleted CLI session meta for session %s", session_id)
except Exception as e:
logger.warning("[Transcript] Failed to delete CLI session meta: %s", e)
# ---------------------------------------------------------------------------
# Transcript compaction — LLM summarization for prompt-too-long recovery

View File

@@ -16,11 +16,11 @@ from .transcript import (
_flatten_assistant_content,
_flatten_tool_result_content,
_messages_to_transcript,
_meta_storage_path_parts,
_rechain_tail,
_sanitize_id,
_storage_path_parts,
_transcript_to_messages,
detect_gap,
extract_context_messages,
strip_for_upload,
validate_transcript,
)
@@ -64,24 +64,6 @@ class TestSanitizeId:
assert _sanitize_id("!@#$%^&*()") == "unknown"
# ---------------------------------------------------------------------------
# _storage_path_parts / _meta_storage_path_parts
# ---------------------------------------------------------------------------
class TestStoragePathParts:
def test_returns_triple(self):
prefix, uid, fname = _storage_path_parts("user-1", "sess-2")
assert prefix == "chat-transcripts"
assert "e" in uid # hex chars from "user-1" sanitized
assert fname.endswith(".jsonl")
def test_meta_returns_meta_json(self):
prefix, _, fname = _meta_storage_path_parts("user-1", "sess-2")
assert prefix == "chat-transcripts"
assert fname.endswith(".meta.json")
# ---------------------------------------------------------------------------
# _build_path_from_parts
# ---------------------------------------------------------------------------
@@ -103,24 +85,6 @@ class TestBuildPathFromParts:
assert path == "local://wid/fid/file.jsonl"
# ---------------------------------------------------------------------------
# TranscriptDownload dataclass
# ---------------------------------------------------------------------------
class TestTranscriptDownload:
def test_defaults(self):
td = TranscriptDownload(content="hello")
assert td.content == "hello"
assert td.message_count == 0
assert td.uploaded_at == 0.0
def test_custom_values(self):
td = TranscriptDownload(content="data", message_count=5, uploaded_at=123.45)
assert td.message_count == 5
assert td.uploaded_at == 123.45
# ---------------------------------------------------------------------------
# _flatten_assistant_content
# ---------------------------------------------------------------------------
@@ -733,190 +697,188 @@ class TestValidateTranscript:
class TestCliSessionPath:
def test_encodes_slashes_to_dashes(self):
from .transcript import _cli_session_path, _projects_base
from .transcript import cli_session_path, projects_base
sdk_cwd = "/tmp/copilot-abc"
result = _cli_session_path(sdk_cwd, "12345678-1234-1234-1234-123456789abc")
base = _projects_base()
result = cli_session_path(sdk_cwd, "12345678-1234-1234-1234-123456789abc")
base = projects_base()
assert result.startswith(base)
# Encoded cwd replaces '/' with '-'
assert "-tmp-copilot-abc" in result
assert result.endswith(".jsonl")
def test_sanitizes_session_id(self):
from .transcript import _cli_session_path
from .transcript import cli_session_path
result = _cli_session_path("/tmp/cwd", "../../etc/passwd")
result = cli_session_path("/tmp/cwd", "../../etc/passwd")
# _sanitize_id strips non-hex/hyphen chars; path traversal impossible
assert ".." not in result
assert "passwd" not in result
class TestUploadCliSession:
def test_skips_upload_when_path_outside_projects_base(self, tmp_path):
"""Files outside the CLI projects base are rejected without upload."""
def test_uploads_content_bytes_successfully(self):
"""Happy path: content bytes are stored as jsonl + meta.json."""
import asyncio
from unittest.mock import AsyncMock, patch
from .transcript import upload_cli_session
from .transcript import upload_transcript
mock_storage = AsyncMock()
content = b'{"type":"assistant"}\n'
with (
patch(
"backend.copilot.transcript._projects_base",
return_value=str(tmp_path),
),
# Return a path that is genuinely outside tmp_path so that
# realpath(session_file).startswith(projects_base + "/") is False
# and the boundary guard actually fires.
patch(
"backend.copilot.transcript._cli_session_path",
return_value="/outside/escaped/session.jsonl",
),
patch(
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
),
with patch(
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
asyncio.run(
upload_cli_session(
upload_transcript(
user_id="user-1",
session_id="12345678-0000-0000-0000-000000000000",
sdk_cwd=str(tmp_path),
session_id="12345678-0000-0000-0000-000000000001",
content=content,
)
)
# storage.store must NOT be called — boundary guard should reject the path
mock_storage.store.assert_not_called()
# Two calls expected: session JSONL + companion .meta.json
assert mock_storage.store.call_count == 2
def test_skips_upload_when_file_not_found(self, tmp_path):
"""Missing CLI session file logs debug and skips upload silently."""
def test_uploads_companion_meta_json_with_message_count(self):
"""upload_transcript stores a companion .meta.json with message_count."""
import asyncio
import json
from unittest.mock import AsyncMock, patch
from .transcript import upload_transcript
mock_storage = AsyncMock()
content = b'{"type":"assistant"}\n'
with patch(
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
asyncio.run(
upload_transcript(
user_id="user-1",
session_id="12345678-0000-0000-0000-000000000010",
content=content,
message_count=5,
)
)
assert mock_storage.store.call_count == 2
# Find the meta.json store call
meta_call = next(
c
for c in mock_storage.store.call_args_list
if c.kwargs.get("filename", "").endswith(".meta.json")
)
meta_content = json.loads(meta_call.kwargs["content"])
assert meta_content["message_count"] == 5
def test_skips_upload_on_storage_failure(self):
"""Storage exception on jsonl write is logged and does not propagate.
With sequential writes, JSONL failure returns early — meta store is
never called, so no rollback is needed.
"""
import asyncio
from unittest.mock import AsyncMock, patch
from .transcript import upload_cli_session
from .transcript import upload_transcript
mock_storage = AsyncMock()
projects_base = str(tmp_path)
mock_storage.store.side_effect = RuntimeError("gcs unavailable")
content = b'{"type":"assistant"}\n'
with (
patch(
"backend.copilot.transcript._projects_base",
return_value=projects_base,
),
patch(
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
),
with patch(
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
# session file doesn't existshould not raise
# Should not raisefailures are logged as warnings
asyncio.run(
upload_cli_session(
upload_transcript(
user_id="user-1",
session_id="12345678-0000-0000-0000-000000000000",
sdk_cwd=str(tmp_path),
)
)
mock_storage.store.assert_not_called()
def test_uploads_file_successfully(self, tmp_path):
"""Happy path: session file exists within projects base → upload called."""
import asyncio
from unittest.mock import AsyncMock, patch
from .transcript import _sanitize_id, upload_cli_session
projects_base = str(tmp_path)
session_id = "12345678-0000-0000-0000-000000000001"
sdk_cwd = str(tmp_path)
# Build the path the same way _cli_session_path does, but using our tmp_path
# as projects_base so the boundary check passes.
# Must use the same encoding: re.sub non-alphanumeric → "-" on realpath.
import os
import re
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
session_dir = tmp_path / encoded_cwd
session_dir.mkdir(parents=True, exist_ok=True)
session_file = session_dir / f"{_sanitize_id(session_id)}.jsonl"
session_file.write_bytes(b'{"type":"assistant"}\n')
mock_storage = AsyncMock()
with (
patch(
"backend.copilot.transcript._projects_base",
return_value=projects_base,
),
patch(
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
),
):
asyncio.run(
upload_cli_session(
user_id="user-1",
session_id=session_id,
sdk_cwd=sdk_cwd,
session_id="12345678-0000-0000-0000-000000000002",
content=content,
)
)
# Only one store call attempted (the JSONL); meta never reached
mock_storage.store.assert_called_once()
mock_storage.delete.assert_not_called()
def test_skips_upload_on_oserror(self, tmp_path):
"""OSError reading session file is logged as warning; upload is skipped."""
def test_rolls_back_session_when_meta_upload_fails(self):
"""When meta upload fails after JSONL succeeds, JSONL is rolled back.
Guarantees the pair is either both present or both absent — avoids an
orphaned JSONL being used with wrong mode/watermark defaults.
"""
import asyncio
from unittest.mock import AsyncMock, patch
from .transcript import _sanitize_id, upload_cli_session
projects_base = str(tmp_path)
sdk_cwd = str(tmp_path)
session_id = "12345678-0000-0000-0000-000000000002"
# Build file at a path inside projects_base so boundary check passes.
import os
import re
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
session_dir = tmp_path / encoded_cwd
session_dir.mkdir(parents=True, exist_ok=True)
session_file = session_dir / f"{_sanitize_id(session_id)}.jsonl"
session_file.write_bytes(b'{"type":"assistant"}\n')
# Remove read permission to trigger OSError
session_file.chmod(0o000)
from .transcript import upload_transcript
mock_storage = AsyncMock()
# First store (JSONL) succeeds; second store (meta) fails
mock_storage.store.side_effect = [None, RuntimeError("meta write failed")]
content = b'{"type":"assistant"}\n'
try:
with (
patch(
"backend.copilot.transcript._projects_base",
return_value=projects_base,
),
patch(
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
),
):
asyncio.run(
upload_cli_session(
user_id="user-1",
session_id=session_id,
sdk_cwd=sdk_cwd,
)
with patch(
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
asyncio.run(
upload_transcript(
user_id="user-1",
session_id="12345678-0000-0000-0000-000000000099",
content=content,
)
finally:
session_file.chmod(0o644) # restore so tmp_path cleanup works
)
mock_storage.store.assert_not_called()
# Both store calls were attempted (JSONL then meta)
assert mock_storage.store.call_count == 2
# JSONL should be rolled back via delete
mock_storage.delete.assert_called_once()
def test_baseline_mode_stored_in_meta(self):
"""upload_transcript with mode='baseline' stores mode in companion meta.json."""
import asyncio
import json
from unittest.mock import AsyncMock, patch
from .transcript import upload_transcript
mock_storage = AsyncMock()
content = b'{"type":"assistant"}\n'
with patch(
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
asyncio.run(
upload_transcript(
user_id="user-1",
session_id="12345678-0000-0000-0000-000000000098",
content=content,
message_count=4,
mode="baseline",
)
)
meta_call = next(
c
for c in mock_storage.store.call_args_list
if c.kwargs.get("filename", "").endswith(".meta.json")
)
meta_content = json.loads(meta_call.kwargs["content"])
assert meta_content["mode"] == "baseline"
assert meta_content["message_count"] == 4
def test_strips_session_before_upload_and_writes_back(self, tmp_path):
"""Strippable entries (progress, thinking blocks) are removed before upload.
@@ -1116,15 +1078,18 @@ class TestUploadCliSession:
class TestRestoreCliSession:
def test_returns_false_when_file_not_found_in_storage(self):
"""Returns False (graceful degradation) when the session is missing."""
def test_returns_none_when_file_not_found_in_storage(self):
"""Returns None (graceful degradation) when the session is missing."""
import asyncio
from unittest.mock import AsyncMock, patch
from .transcript import restore_cli_session
from .transcript import download_transcript
mock_storage = AsyncMock()
mock_storage.retrieve.side_effect = FileNotFoundError("not found")
mock_storage.retrieve.side_effect = [
FileNotFoundError("no session"),
FileNotFoundError("no meta"),
]
with patch(
"backend.copilot.transcript.get_workspace_storage",
@@ -1132,144 +1097,26 @@ class TestRestoreCliSession:
return_value=mock_storage,
):
result = asyncio.run(
restore_cli_session(
download_transcript(
user_id="user-1",
session_id="12345678-0000-0000-0000-000000000000",
sdk_cwd="/tmp/copilot-test",
)
)
assert result is False
assert result is None
def test_returns_false_when_restore_path_outside_projects_base(self, tmp_path):
"""Path traversal guard: rejects restoration outside the projects base."""
def test_returns_transcript_download_on_success_no_meta(self):
"""Happy path with no meta.json: returns TranscriptDownload with message_count=0."""
import asyncio
from unittest.mock import AsyncMock, patch
from .transcript import restore_cli_session
from .transcript import download_transcript
mock_storage = AsyncMock()
mock_storage.retrieve.return_value = b'{"type":"assistant"}\n'
with (
patch(
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
),
patch(
"backend.copilot.transcript._projects_base",
return_value=str(tmp_path),
),
# Return a path genuinely outside tmp_path so the boundary guard fires.
patch(
"backend.copilot.transcript._cli_session_path",
return_value="/outside/escaped/session.jsonl",
),
):
result = asyncio.run(
restore_cli_session(
user_id="user-1",
session_id="12345678-0000-0000-0000-000000000000",
sdk_cwd=str(tmp_path),
)
)
assert result is False
def test_returns_true_when_local_file_already_exists(self, tmp_path):
"""Same-pod reuse: if local file exists, skip storage download and return True."""
import asyncio
import os
import re
from pathlib import Path
from unittest.mock import AsyncMock, patch
from .transcript import restore_cli_session
session_id = "12345678-0000-0000-0000-000000000099"
sdk_cwd = str(tmp_path)
# Pre-create the local session file (simulates previous turn on same pod)
projects_base = os.path.realpath(str(tmp_path))
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", projects_base)
session_dir = Path(projects_base) / encoded_cwd
session_dir.mkdir(parents=True, exist_ok=True)
existing_content = b'{"type":"user"}\n{"type":"assistant"}\n'
(session_dir / f"{session_id}.jsonl").write_bytes(existing_content)
mock_storage = AsyncMock()
with (
patch(
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
),
patch(
"backend.copilot.transcript._projects_base",
return_value=projects_base,
),
):
result = asyncio.run(
restore_cli_session(
user_id="user-1",
session_id=session_id,
sdk_cwd=sdk_cwd,
)
)
assert result is True
# Storage should NOT have been accessed (local file was used as-is)
mock_storage.retrieve.assert_not_called()
# Local file should be unchanged
assert (session_dir / f"{session_id}.jsonl").read_bytes() == existing_content
def test_returns_true_on_success(self, tmp_path):
"""Happy path: storage has the session → file written → returns True."""
import asyncio
from unittest.mock import AsyncMock, patch
from .transcript import restore_cli_session
projects_base = str(tmp_path)
sdk_cwd = str(tmp_path)
session_id = "12345678-0000-0000-0000-000000000003"
content = b'{"type":"assistant"}\n'
mock_storage = AsyncMock()
mock_storage.retrieve.return_value = content
with (
patch(
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
),
patch(
"backend.copilot.transcript._projects_base",
return_value=projects_base,
),
):
result = asyncio.run(
restore_cli_session(
user_id="user-1",
session_id=session_id,
sdk_cwd=sdk_cwd,
)
)
assert result is True
def test_returns_false_on_download_exception(self):
"""Non-FileNotFoundError during retrieve logs warning and returns False."""
import asyncio
from unittest.mock import AsyncMock, patch
from .transcript import restore_cli_session
mock_storage = AsyncMock()
mock_storage.retrieve.side_effect = RuntimeError("network error")
mock_storage.retrieve.side_effect = [content, FileNotFoundError("no meta")]
with patch(
"backend.copilot.transcript.get_workspace_storage",
@@ -1277,11 +1124,411 @@ class TestRestoreCliSession:
return_value=mock_storage,
):
result = asyncio.run(
restore_cli_session(
download_transcript(
user_id="user-1",
session_id="12345678-0000-0000-0000-000000000004",
sdk_cwd="/tmp/copilot-test",
session_id=session_id,
)
)
assert result is False
assert isinstance(result, TranscriptDownload)
assert result.content == content
assert result.message_count == 0
assert result.mode == "sdk"
def test_returns_transcript_download_with_message_count_from_meta(self):
"""When meta.json is present, message_count and mode are read from it."""
import asyncio
import json
from unittest.mock import AsyncMock, patch
from .transcript import download_transcript
session_id = "12345678-0000-0000-0000-000000000005"
content = b'{"type":"assistant"}\n'
meta_bytes = json.dumps(
{"message_count": 7, "mode": "sdk", "uploaded_at": 1234567.0}
).encode()
mock_storage = AsyncMock()
mock_storage.retrieve.side_effect = [content, meta_bytes]
with patch(
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
result = asyncio.run(
download_transcript(
user_id="user-1",
session_id=session_id,
)
)
assert isinstance(result, TranscriptDownload)
assert result.content == content
assert result.message_count == 7
assert result.mode == "sdk"
def test_returns_none_on_download_exception(self):
"""Non-FileNotFoundError during retrieve logs warning and returns None."""
import asyncio
from unittest.mock import AsyncMock, patch
from .transcript import download_transcript
mock_storage = AsyncMock()
mock_storage.retrieve.side_effect = [
RuntimeError("network error"),
FileNotFoundError("no meta"),
]
with patch(
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
result = asyncio.run(
download_transcript(
user_id="user-1",
session_id="12345678-0000-0000-0000-000000000004",
)
)
assert result is None
def test_baseline_mode_in_meta_returned(self):
"""When meta.json contains mode='baseline', result.mode is 'baseline'."""
import asyncio
import json
from unittest.mock import AsyncMock, patch
from .transcript import download_transcript
content = b'{"type":"assistant"}\n'
meta_bytes = json.dumps(
{"message_count": 3, "mode": "baseline", "uploaded_at": 0.0}
).encode()
mock_storage = AsyncMock()
mock_storage.retrieve.side_effect = [content, meta_bytes]
with patch(
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
result = asyncio.run(
download_transcript(
user_id="user-1",
session_id="12345678-0000-0000-0000-000000000020",
)
)
assert isinstance(result, TranscriptDownload)
assert result.mode == "baseline"
assert result.message_count == 3
def test_invalid_mode_in_meta_defaults_to_sdk(self):
"""Unknown mode value in meta.json falls back to 'sdk'."""
import asyncio
import json
from unittest.mock import AsyncMock, patch
from .transcript import download_transcript
content = b'{"type":"assistant"}\n'
meta_bytes = json.dumps({"message_count": 2, "mode": "unknown_mode"}).encode()
mock_storage = AsyncMock()
mock_storage.retrieve.side_effect = [content, meta_bytes]
with patch(
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
result = asyncio.run(
download_transcript(
user_id="user-1",
session_id="12345678-0000-0000-0000-000000000021",
)
)
assert isinstance(result, TranscriptDownload)
assert result.mode == "sdk"
def test_invalid_utf8_meta_uses_defaults(self):
"""Meta bytes that fail UTF-8 decode fall back to message_count=0, mode='sdk'."""
import asyncio
from unittest.mock import AsyncMock, patch
from .transcript import download_transcript
content = b'{"type":"assistant"}\n'
bad_meta = b"\xff\xfe"
mock_storage = AsyncMock()
mock_storage.retrieve.side_effect = [content, bad_meta]
with patch(
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
result = asyncio.run(
download_transcript(
user_id="user-1",
session_id="12345678-0000-0000-0000-000000000022",
)
)
assert isinstance(result, TranscriptDownload)
assert result.message_count == 0
assert result.mode == "sdk"
def test_meta_fetch_exception_uses_defaults(self):
"""Non-FileNotFoundError on meta fetch still returns content with defaults."""
import asyncio
from unittest.mock import AsyncMock, patch
from .transcript import download_transcript
content = b'{"type":"assistant"}\n'
mock_storage = AsyncMock()
mock_storage.retrieve.side_effect = [content, RuntimeError("meta unavailable")]
with patch(
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
result = asyncio.run(
download_transcript(
user_id="user-1",
session_id="12345678-0000-0000-0000-000000000023",
)
)
assert isinstance(result, TranscriptDownload)
assert result.content == content
assert result.message_count == 0
assert result.mode == "sdk"
# ---------------------------------------------------------------------------
# detect_gap
# ---------------------------------------------------------------------------
def _msgs(*roles: str):
"""Build a list of ChatMessage objects with the given roles."""
from .model import ChatMessage
return [ChatMessage(role=r, content=f"{r}-{i}") for i, r in enumerate(roles)]
class TestDetectGap:
"""``detect_gap`` returns messages between transcript watermark and current turn."""
def _dl(self, message_count: int) -> TranscriptDownload:
return TranscriptDownload(content=b"", message_count=message_count, mode="sdk")
def test_zero_watermark_returns_empty(self):
"""message_count=0 means no watermark — skip gap detection."""
dl = self._dl(0)
messages = _msgs("user", "assistant", "user")
assert detect_gap(dl, messages) == []
def test_watermark_covers_all_prefix_returns_empty(self):
"""Transcript already covers all messages up to the current user turn."""
# session: [user, assistant, user(current)] — wm=2 means covers up to assistant
dl = self._dl(2)
messages = _msgs("user", "assistant", "user")
assert detect_gap(dl, messages) == []
def test_watermark_exceeds_session_returns_empty(self):
"""Watermark ahead of session count (race / over-count) → no gap."""
dl = self._dl(10)
messages = _msgs("user", "assistant", "user")
assert detect_gap(dl, messages) == []
def test_misaligned_watermark_not_on_assistant_returns_empty(self):
"""Watermark at a user-role position is misaligned — skip gap."""
# wm=1: position 0 is 'user', not 'assistant' → skip
dl = self._dl(1)
messages = _msgs("user", "assistant", "user", "assistant", "user")
assert detect_gap(dl, messages) == []
def test_returns_gap_messages(self):
"""Watermark behind session — gap messages returned (excluding current turn)."""
# session: [user0, assistant1, user2, assistant3, user4(current)]
# wm=2: transcript covers [0,1]; gap = [user2, assistant3]
dl = self._dl(2)
messages = _msgs("user", "assistant", "user", "assistant", "user")
gap = detect_gap(dl, messages)
assert len(gap) == 2
assert gap[0].role == "user"
assert gap[1].role == "assistant"
def test_excludes_current_user_turn(self):
"""The last message (current user turn) is never included in the gap."""
# wm=2, session has 4 msgs: gap = [msg2] only (msg3 is current turn → excluded)
dl = self._dl(2)
messages = _msgs("user", "assistant", "user", "user")
gap = detect_gap(dl, messages)
assert len(gap) == 1
assert gap[0].role == "user"
def test_single_gap_message(self):
"""One message between watermark and current turn."""
# session: [user0, assistant1, user2, assistant3, user4(current)]
# wm=3: position 2 is 'user' → misaligned, returns []
# use wm=4: but 4 >= total-1=4 → also empty
# wm=3 with session [u, a, u, a, u, a, u(current)]: position 2 is 'user' → empty
# Valid case: wm=2 has 3 messages (assistant at 1), wm=4 with [u,a,u,a,u,a,u]:
# let's use wm=4 with 7 messages: wm=4 >= total-1=6? no, 4<6. pos[3]=assistant → gap=[msg4,msg5]
# simpler: wm=2, [u0,a1,a2,u3(current)] — pos[1]=assistant, gap=[a2] only
dl = self._dl(2)
messages = _msgs("user", "assistant", "assistant", "user")
gap = detect_gap(dl, messages)
assert len(gap) == 1
assert gap[0].role == "assistant"
# ---------------------------------------------------------------------------
# extract_context_messages
# ---------------------------------------------------------------------------
def _make_valid_transcript(*roles: str) -> str:
"""Build a minimal valid JSONL transcript with the given message roles."""
import json as stdlib_json
from .transcript import STOP_REASON_END_TURN
lines = []
parent = ""
for i, role in enumerate(roles):
uid = f"uid-{i}"
entry: dict = {
"type": role,
"uuid": uid,
"parentUuid": parent,
"message": {
"role": role,
"content": f"{role} content {i}",
},
}
if role == "assistant":
entry["message"]["id"] = f"msg_{i}"
entry["message"]["model"] = "test-model"
entry["message"]["type"] = "message"
entry["message"]["stop_reason"] = STOP_REASON_END_TURN
entry["message"]["content"] = [
{"type": "text", "text": f"assistant content {i}"}
]
lines.append(stdlib_json.dumps(entry))
parent = uid
return "\n".join(lines) + "\n"
class TestExtractContextMessages:
"""``extract_context_messages`` returns the shared context primitive."""
def test_none_download_returns_prior(self):
"""No download → falls back to all session messages except current turn."""
messages = _msgs("user", "assistant", "user")
result = extract_context_messages(None, messages)
assert result == messages[:-1]
assert len(result) == 2
def test_empty_content_download_returns_prior(self):
"""Empty bytes content → falls back to all prior messages."""
dl = TranscriptDownload(content=b"", message_count=2, mode="sdk")
messages = _msgs("user", "assistant", "user")
result = extract_context_messages(dl, messages)
assert result == messages[:-1]
def test_valid_transcript_no_gap_returns_transcript_messages(self):
"""Transcript covers all prior turns → only transcript messages returned."""
# Transcript: [user, assistant] — 2 messages
# Session: [user, assistant, user(current)] — watermark=2 covers prefix
transcript_content = _make_valid_transcript("user", "assistant")
dl = TranscriptDownload(
content=transcript_content.encode("utf-8"), message_count=2, mode="sdk"
)
messages = _msgs("user", "assistant", "user")
result = extract_context_messages(dl, messages)
# Transcript has 2 messages (user + assistant) and no gap
assert len(result) == 2
assert result[0].role == "user"
assert result[1].role == "assistant"
def test_valid_transcript_with_gap_returns_transcript_plus_gap(self):
"""Transcript is stale → gap messages appended after transcript content."""
# Transcript: [user, assistant] — watermark=2
# Session: [user, assistant, user, assistant, user(current)]
# Gap: [user(2), assistant(3)] — positions 2 and 3
transcript_content = _make_valid_transcript("user", "assistant")
dl = TranscriptDownload(
content=transcript_content.encode("utf-8"), message_count=2, mode="sdk"
)
messages = _msgs("user", "assistant", "user", "assistant", "user")
result = extract_context_messages(dl, messages)
# 2 transcript messages + 2 gap messages = 4
assert len(result) == 4
assert result[0].role == "user" # transcript user
assert result[1].role == "assistant" # transcript assistant
assert result[2].role == "user" # gap user
assert result[3].role == "assistant" # gap assistant
def test_compact_summary_entries_preserved(self):
"""``isCompactSummary=True`` entries survive ``_transcript_to_messages``."""
import json as stdlib_json
from .transcript import STOP_REASON_END_TURN
# Build a transcript where one entry is a compaction summary.
# isCompactSummary=True entries have type in STRIPPABLE_TYPES but are kept.
compact_entry = stdlib_json.dumps(
{
"type": "summary",
"uuid": "uid-compact",
"parentUuid": "",
"isCompactSummary": True,
"message": {
"role": "user",
"content": "COMPACT_SUMMARY_CONTENT",
},
}
)
assistant_entry = stdlib_json.dumps(
{
"type": "assistant",
"uuid": "uid-1",
"parentUuid": "uid-compact",
"message": {
"role": "assistant",
"id": "msg_1",
"model": "test",
"type": "message",
"stop_reason": STOP_REASON_END_TURN,
"content": [{"type": "text", "text": "response after compact"}],
},
}
)
content = compact_entry + "\n" + assistant_entry + "\n"
dl = TranscriptDownload(
content=content.encode("utf-8"), message_count=2, mode="sdk"
)
messages = _msgs("user", "assistant", "user")
result = extract_context_messages(dl, messages)
# Both the compact summary and the assistant response are present
assert len(result) == 2
roles = [m.role for m in result]
assert "user" in roles # compact summary has role=user
assert "assistant" in roles
# The compact summary content is preserved
compact_msgs = [m for m in result if m.role == "user"]
assert any("COMPACT_SUMMARY_CONTENT" in (m.content or "") for m in compact_msgs)

View File

@@ -88,17 +88,19 @@ async def cmd_download(session_ids: list[str]) -> None:
print(f"[{sid[:12]}] Not found in GCS")
continue
content_str = (
dl.content.decode("utf-8") if isinstance(dl.content, bytes) else dl.content
)
out = _transcript_path(sid)
with open(out, "w") as f:
f.write(dl.content)
f.write(content_str)
lines = len(dl.content.strip().split("\n"))
lines = len(content_str.strip().split("\n"))
meta = {
"session_id": sid,
"user_id": user_id,
"message_count": dl.message_count,
"uploaded_at": dl.uploaded_at,
"transcript_bytes": len(dl.content),
"transcript_bytes": len(content_str),
"transcript_lines": lines,
}
with open(_meta_path(sid), "w") as f:
@@ -106,7 +108,7 @@ async def cmd_download(session_ids: list[str]) -> None:
print(
f"[{sid[:12]}] Saved: {lines} entries, "
f"{len(dl.content)} bytes, msg_count={dl.message_count}"
f"{len(content_str)} bytes, msg_count={dl.message_count}"
)
print("\nDone. Run 'load' command to import into local dev environment.")
@@ -227,7 +229,7 @@ async def cmd_load(session_ids: list[str]) -> None:
await upload_transcript(
user_id=user_id,
session_id=sid,
content=content,
content=content.encode("utf-8"),
message_count=msg_count,
)
print(f"[{sid[:12]}] Stored transcript in local workspace storage")

View File

@@ -0,0 +1,140 @@
"""Unit tests for the transcript watermark (message_count) fix.
The bug: upload used message_count=len(session.messages) (DB count). When a
prior turn's GCS upload failed silently, the JSONL on GCS was stale (e.g.
covered only T1-T12) but the meta.json watermark matched the full DB count
(e.g. 46). The next turn's gap-fill check (transcript_msg_count < msg_count-1)
never triggered, so the model silently lost context for the skipped turns.
The fix: watermark = previous_coverage + 2 (current user+asst pair) when
use_resume=True and transcript_msg_count > 0. This ensures the watermark
reflects the JSONL content, not the DB count.
These tests exercise _build_query_message directly to verify that gap-fill
triggers with the corrected watermark but NOT with the inflated (buggy) one.
"""
from unittest.mock import MagicMock
import pytest
from backend.copilot.sdk.service import _build_query_message
def _make_messages(n_pairs: int, *, current_user: str = "current") -> list[MagicMock]:
"""Build a flat list of n_pairs*2 alternating user/asst messages, plus
one trailing user message for the *current* turn."""
msgs: list[MagicMock] = []
for i in range(n_pairs):
u = MagicMock()
u.role = "user"
u.content = f"user message {i}"
a = MagicMock()
a.role = "assistant"
a.content = f"assistant response {i}"
msgs.extend([u, a])
# Current turn's user message
cur = MagicMock()
cur.role = "user"
cur.content = current_user
msgs.append(cur)
return msgs
def _make_session(messages: list[MagicMock]) -> MagicMock:
session = MagicMock()
session.messages = messages
return session
@pytest.mark.asyncio
async def test_gap_fill_triggers_for_stale_jsonl():
"""Scenario: T1-T12 in JSONL (watermark=24), DB has T1-T22+Test (46 msgs).
With the FIX: 'Test' uploaded watermark=26 (T12's 24 + 2 for 'Test').
Next turn (T24) downloads watermark=26, DB has 47.
Gap check: 26 < 47-1=46 → TRUE → gap fills T14-T23.
"""
# T23 turns in DB (46 messages) + T24 user = 47
msgs = _make_messages(23, current_user="memory test - recall all")
assert len(msgs) == 47
session = _make_session(msgs)
# Watermark as uploaded by the FIX: T12 covered 24, 'Test' +2 = 26
result_msg, _ = await _build_query_message(
current_message="memory test - recall all",
session=session,
use_resume=True,
transcript_msg_count=26,
session_id="test-session-id",
)
assert "<conversation_history>" in result_msg, (
"Expected gap-fill to inject <conversation_history> when "
"watermark=26 < msg_count-1=46"
)
@pytest.mark.asyncio
async def test_no_gap_fill_when_watermark_is_current():
"""When the JSONL is fully current (watermark = DB-1), no gap injected."""
# T23 turns in DB (46 messages) + T24 user = 47
msgs = _make_messages(23, current_user="next message")
session = _make_session(msgs)
result_msg, _ = await _build_query_message(
current_message="next message",
session=session,
use_resume=True,
transcript_msg_count=46, # current — no gap
session_id="test-session-id",
)
assert (
"<conversation_history>" not in result_msg
), "No gap-fill expected when watermark is current"
assert result_msg == "next message"
@pytest.mark.asyncio
async def test_inflated_watermark_suppresses_gap_fill():
"""Documents the original bug: inflated watermark suppresses gap-fill.
'Test' uploaded watermark=len(session.messages)=46 even though only 26
messages are in the JSONL. Next turn: 46 < 47-1=46 → FALSE → no gap fill.
"""
msgs = _make_messages(23, current_user="memory test")
session = _make_session(msgs)
# Buggy watermark: inflated to DB count
result_msg, _ = await _build_query_message(
current_message="memory test",
session=session,
use_resume=True,
transcript_msg_count=46, # inflated — suppresses gap fill
session_id="test-session-id",
)
assert (
"<conversation_history>" not in result_msg
), "With inflated watermark, gap-fill is suppressed — this documents the bug"
@pytest.mark.asyncio
async def test_fixed_watermark_fills_same_gap():
"""Same scenario but with the FIXED watermark triggers gap-fill."""
msgs = _make_messages(23, current_user="memory test")
session = _make_session(msgs)
result_msg, _ = await _build_query_message(
current_message="memory test",
session=session,
use_resume=True,
transcript_msg_count=26, # fixed watermark
session_id="test-session-id",
)
assert (
"<conversation_history>" in result_msg
), "With fixed watermark=26, gap-fill triggers and injects missing turns"