Mirror of https://github.com/Significant-Gravitas/AutoGPT.git, synced 2026-04-30 03:00:41 -04:00
chore: merge master into dev, resolve baseline/transcript conflicts
Conflicts in baseline/service.py, baseline/transcript_integration_test.py,
and transcript.py arose because dev-only commit 0cd0a76305
(baseline upload fix) overlapped with the same fix in PR #12804 which
landed in master. Took master's version for all three files — it is the
complete, reviewed implementation.
@@ -67,11 +67,15 @@ from backend.copilot.transcript import (
STOP_REASON_END_TURN,
STOP_REASON_TOOL_USE,
TranscriptDownload,
detect_gap,
download_transcript,
extract_context_messages,
strip_for_upload,
upload_transcript,
validate_transcript,
)
from backend.copilot.transcript_builder import TranscriptBuilder
from backend.util import json as util_json
from backend.util.exceptions import NotFoundError
from backend.util.prompt import (
compress_context,
@@ -699,29 +703,7 @@ async def _compress_session_messages(
return messages


def is_transcript_stale(dl: TranscriptDownload | None, session_msg_count: int) -> bool:
"""Return ``True`` when a download doesn't cover the current session.

A transcript is stale when it has a known ``message_count`` and that
count doesn't reach ``session_msg_count - 1`` (i.e. the session has
already advanced beyond what the stored transcript captures).
Loading a stale transcript would silently drop intermediate turns,
so callers should treat stale as "skip load, skip upload".

An unknown ``message_count`` (``0``) is treated as **not stale**
because older transcripts uploaded before msg_count tracking
existed must still be usable.
"""
if dl is None:
return False
if not dl.message_count:
return False
return dl.message_count < session_msg_count - 1

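# A minimal sketch of the boundary above (editor's illustration, assuming the
# TranscriptDownload constructor used in the tests later in this diff):
#
#     is_transcript_stale(TranscriptDownload(content="", message_count=4), session_msg_count=6)  # True  (4 < 6 - 1)
#     is_transcript_stale(TranscriptDownload(content="", message_count=5), session_msg_count=6)  # False (5 == 6 - 1)
#     is_transcript_stale(TranscriptDownload(content="", message_count=0), session_msg_count=6)  # False (unknown count)
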
def should_upload_transcript(
user_id: str | None, upload_safe: bool
) -> bool:
def should_upload_transcript(user_id: str | None, upload_safe: bool) -> bool:
"""Return ``True`` when the caller should upload the final transcript.

Uploads require a logged-in user (for the storage key) *and* a safe
@@ -731,55 +713,137 @@ def should_upload_transcript(
return bool(user_id) and upload_safe

def _append_gap_to_builder(
gap: list[ChatMessage],
builder: TranscriptBuilder,
) -> None:
"""Append gap messages from chat-db into the TranscriptBuilder.

Converts ChatMessage (OpenAI format) to TranscriptBuilder entries
(Claude CLI JSONL format) so the uploaded transcript covers all turns.

Pre-condition: ``gap`` always starts at a user or assistant boundary
(never mid-turn at a ``tool`` role), because ``detect_gap`` enforces
``session_messages[wm-1].role == 'assistant'`` before returning a non-empty
gap. Any ``tool`` role messages within the gap always follow an assistant
entry that already exists in the builder or in the gap itself.
"""
for msg in gap:
if msg.role == "user":
builder.append_user(msg.content or "")
elif msg.role == "assistant":
content_blocks: list[dict] = []
if msg.content:
content_blocks.append({"type": "text", "text": msg.content})
if msg.tool_calls:
for tc in msg.tool_calls:
fn = tc.get("function", {}) if isinstance(tc, dict) else {}
input_data = util_json.loads(fn.get("arguments", "{}"), fallback={})
content_blocks.append(
{
"type": "tool_use",
"id": tc.get("id", "") if isinstance(tc, dict) else "",
"name": fn.get("name", "unknown"),
"input": input_data,
}
)
if not content_blocks:
# Fallback: ensure every assistant gap message produces an entry
# so the builder's entry count matches the gap length.
content_blocks.append({"type": "text", "text": ""})
builder.append_assistant(content_blocks=content_blocks)
elif msg.role == "tool":
if msg.tool_call_id:
builder.append_tool_result(
tool_use_id=msg.tool_call_id,
content=msg.content or "",
)
else:
# Malformed tool message — no tool_call_id to link to an
# assistant tool_use block. Skip to avoid an unmatched
# tool_result entry in the builder (which would confuse --resume).
logger.warning(
"[Baseline] Skipping tool gap message with no tool_call_id"
)

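# A minimal sketch of the conversion above (editor's illustration, using the
# OpenAI-style tool_call shape exercised by the tests later in this diff):
# an assistant gap message such as
#
#     ChatMessage(role="assistant", content=None,
#                 tool_calls=[{"id": "tc-1", "type": "function",
#                              "function": {"name": "my_tool", "arguments": '{"key": "val"}'}}])
#
# becomes a single builder entry whose content block is
#
#     {"type": "tool_use", "id": "tc-1", "name": "my_tool", "input": {"key": "val"}}
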
async def _load_prior_transcript(
user_id: str,
session_id: str,
session_msg_count: int,
session_messages: list[ChatMessage],
transcript_builder: TranscriptBuilder,
) -> bool:
"""Download and load the prior transcript into ``transcript_builder``.
) -> tuple[bool, "TranscriptDownload | None"]:
"""Download and load the prior CLI session into ``transcript_builder``.

Returns ``True`` when upload is safe at the end of this turn; ``False``
when GCS has a *newer* version that we must not overwrite (stale case).

Upload is suppressed only for **stale** transcripts (GCS watermark >
current turn's prefix) and **download errors** (we can't know what GCS
holds). Missing and invalid transcripts return ``True`` because there is
nothing in GCS worth protecting — uploading is always safe.
Returns a tuple of (upload_safe, transcript_download):
- ``upload_safe`` is ``True`` when it is safe to upload at the end of this
turn. Upload is suppressed only for **download errors** (unknown GCS
state) — missing and invalid files return ``True`` because there is
nothing in GCS worth protecting against overwriting.
- ``transcript_download`` is a ``TranscriptDownload`` with str content
(pre-decoded and stripped) when available, or ``None`` when no valid
transcript could be loaded. Callers pass this to
``extract_context_messages`` to build the LLM context.
"""
try:
dl = await download_transcript(user_id, session_id, log_prefix="[Baseline]")
except Exception as e:
logger.warning("[Baseline] Transcript download failed: %s", e)
# Unknown GCS state — be conservative and skip upload.
return False

if dl is None:
logger.debug("[Baseline] No transcript available — will upload fresh")
# Nothing in GCS to protect; allow upload.
return True

if not validate_transcript(dl.content):
logger.warning("[Baseline] Downloaded transcript is invalid — will overwrite")
# Corrupt file in GCS; uploading a valid one is strictly better.
return True

if is_transcript_stale(dl, session_msg_count):
logger.warning(
"[Baseline] Transcript stale: covers %d of %d messages, skipping",
dl.message_count,
session_msg_count,
restore = await download_transcript(
user_id, session_id, log_prefix="[Baseline]"
)
# GCS watermark is ahead of this turn — do not overwrite.
return False
except Exception as e:
logger.warning("[Baseline] Session restore failed: %s", e)
# Unknown GCS state — be conservative, skip upload.
return False, None

transcript_builder.load_previous(dl.content, log_prefix="[Baseline]")
if restore is None:
logger.debug("[Baseline] No CLI session available — will upload fresh")
# Nothing in GCS to protect; allow upload so the first baseline turn
# writes the initial transcript snapshot.
return True, None

content_bytes = restore.content
try:
raw_str = (
content_bytes.decode("utf-8")
if isinstance(content_bytes, bytes)
else content_bytes
)
except UnicodeDecodeError:
logger.warning("[Baseline] CLI session content is not valid UTF-8")
# Corrupt file in GCS; overwriting with a valid one is better.
return True, None

stripped = strip_for_upload(raw_str)
if not validate_transcript(stripped):
logger.warning("[Baseline] CLI session content invalid after strip")
# Corrupt file in GCS; overwriting with a valid one is better.
return True, None

transcript_builder.load_previous(stripped, log_prefix="[Baseline]")
logger.info(
"[Baseline] Loaded transcript: %dB, msg_count=%d",
len(dl.content),
dl.message_count,
"[Baseline] Loaded CLI session: %dB, msg_count=%d",
len(content_bytes) if isinstance(content_bytes, bytes) else len(raw_str),
restore.message_count,
)
return True

gap = detect_gap(restore, session_messages)
if gap:
_append_gap_to_builder(gap, transcript_builder)
logger.info(
"[Baseline] Filled gap: loaded %d transcript msgs + %d gap msgs from DB",
restore.message_count,
len(gap),
)

# Return a str-content version so extract_context_messages receives a
# pre-decoded, stripped transcript (avoids redundant decode + strip).
# TranscriptDownload.content is typed as bytes | str; we pass str here
# to avoid a redundant encode + decode round-trip.
str_restore = TranscriptDownload(
content=stripped,
message_count=restore.message_count,
mode=restore.mode,
)
return True, str_restore

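# A minimal sketch of the intended call pattern (editor's illustration, using
# only helpers named in this diff; the real wiring is in
# stream_chat_completion_baseline below):
#
#     upload_safe, dl = await _load_prior_transcript(
#         user_id=user_id,
#         session_id=session_id,
#         session_msg_count=len(session.messages),
#         session_messages=session.messages,
#         transcript_builder=transcript_builder,
#     )
#     prior_context = extract_context_messages(dl, session.messages)
#     if should_upload_transcript(user_id, upload_safe):
#         ...  # schedule _upload_final_transcript once the turn completes
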
async def _upload_final_transcript(
@@ -813,10 +877,10 @@ async def _upload_final_transcript(
upload_transcript(
user_id=user_id,
session_id=session_id,
content=content,
content=content.encode("utf-8"),
message_count=session_msg_count,
mode="baseline",
log_prefix="[Baseline]",
skip_strip=True,
)
)
_background_tasks.add(upload_task)
@@ -920,15 +984,16 @@ async def stream_chat_completion_baseline(

# Run download + prompt build concurrently — both are independent I/O
# on the request critical path.
transcript_download: TranscriptDownload | None = None
if user_id and len(session.messages) > 1:
(
transcript_upload_safe,
(transcript_upload_safe, transcript_download),
(base_system_prompt, understanding),
) = await asyncio.gather(
_load_prior_transcript(
user_id=user_id,
session_id=session_id,
session_msg_count=len(session.messages),
session_messages=session.messages,
transcript_builder=transcript_builder,
),
prompt_task,
@@ -968,9 +1033,14 @@ async def stream_chat_completion_baseline(

warm_ctx = await fetch_warm_context(user_id, message or "")

# Compress context if approaching the model's token limit
# Context path: transcript content (compacted, isCompactSummary preserved) +
# gap (DB messages after watermark) + current user turn.
# This avoids re-reading the full session history from DB on every turn.
# See extract_context_messages() in transcript.py for the shared primitive.
prior_context = extract_context_messages(transcript_download, session.messages)
messages_for_context = await _compress_session_messages(
session.messages, model=active_model
prior_context + ([session.messages[-1]] if session.messages else []),
model=active_model,
)

# Build OpenAI message list from session history.

@@ -1,7 +1,7 @@
"""Integration tests for baseline transcript flow.

Exercises the real helpers in ``baseline/service.py`` that download,
validate, load, append to, backfill, and upload the transcript.
Exercises the real helpers in ``baseline/service.py`` that restore,
validate, load, append to, backfill, and upload the CLI session.
Storage is mocked via ``download_transcript`` / ``upload_transcript``
patches; no network access is required.
"""
@@ -12,13 +12,14 @@ from unittest.mock import AsyncMock, patch
import pytest

from backend.copilot.baseline.service import (
_append_gap_to_builder,
_load_prior_transcript,
_record_turn_to_transcript,
_resolve_baseline_model,
_upload_final_transcript,
is_transcript_stale,
should_upload_transcript,
)
from backend.copilot.model import ChatMessage
from backend.copilot.service import config
from backend.copilot.transcript import (
STOP_REASON_END_TURN,
@@ -54,6 +55,13 @@ def _make_transcript_content(*roles: str) -> str:
return "\n".join(lines) + "\n"


def _make_session_messages(*roles: str) -> list[ChatMessage]:
"""Build a list of ChatMessage objects matching the given roles."""
return [
ChatMessage(role=r, content=f"{r} message {i}") for i, r in enumerate(roles)
]


class TestResolveBaselineModel:
"""Model selection honours the per-request mode."""

@@ -73,89 +81,102 @@ class TestResolveBaselineModel:


class TestLoadPriorTranscript:
"""``_load_prior_transcript`` wraps the download + validate + load flow."""
"""``_load_prior_transcript`` wraps the CLI session restore + validate + load flow."""

@pytest.mark.asyncio
async def test_loads_fresh_transcript(self):
builder = TranscriptBuilder()
content = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=content, message_count=2)
restore = TranscriptDownload(
content=content.encode("utf-8"), message_count=2, mode="sdk"
)

with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
session_messages=_make_session_messages("user", "assistant", "user"),
transcript_builder=builder,
)

assert covers is True
assert dl is not None
assert dl.message_count == 2
assert builder.entry_count == 2
assert builder.last_entry_type == "assistant"

@pytest.mark.asyncio
async def test_rejects_stale_transcript(self):
"""msg_count strictly less than session-1 is treated as stale."""
async def test_fills_gap_when_transcript_is_behind(self):
"""When transcript covers fewer messages than session, gap is filled from DB."""
builder = TranscriptBuilder()
content = _make_transcript_content("user", "assistant")
# session has 6 messages, transcript only covers 2 → stale.
download = TranscriptDownload(content=content, message_count=2)
# transcript covers 2 messages, session has 4 (plus current user turn = 5)
restore = TranscriptDownload(
content=content.encode("utf-8"), message_count=2, mode="baseline"
)

with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=6,
session_messages=_make_session_messages(
"user", "assistant", "user", "assistant", "user"
),
transcript_builder=builder,
)

assert covers is False
assert builder.is_empty
assert covers is True
assert dl is not None
# 2 from transcript + 2 gap messages (user+assistant at positions 2,3)
assert builder.entry_count == 4

@pytest.mark.asyncio
async def test_missing_transcript_allows_upload(self):
"""Nothing in GCS → safe to upload fresh transcript after the turn."""
"""Nothing in GCS → upload is safe; the turn writes the first snapshot."""
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=None),
):
upload_safe = await _load_prior_transcript(
upload_safe, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
session_messages=_make_session_messages("user", "assistant"),
transcript_builder=builder,
)

assert upload_safe is True
assert dl is None
assert builder.is_empty

@pytest.mark.asyncio
async def test_invalid_transcript_allows_upload(self):
"""Corrupt file in GCS → overwriting with valid data is better."""
"""Corrupt file in GCS → overwriting with a valid one is better."""
builder = TranscriptBuilder()
download = TranscriptDownload(
content='{"type":"progress","uuid":"a"}\n',
restore = TranscriptDownload(
content=b'{"type":"progress","uuid":"a"}\n',
message_count=1,
mode="sdk",
)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
upload_safe = await _load_prior_transcript(
upload_safe, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
session_messages=_make_session_messages("user", "assistant"),
transcript_builder=builder,
)

assert upload_safe is True
assert dl is None
assert builder.is_empty

@pytest.mark.asyncio
@@ -165,36 +186,39 @@ class TestLoadPriorTranscript:
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(side_effect=RuntimeError("boom")),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
session_messages=_make_session_messages("user", "assistant"),
transcript_builder=builder,
)

assert covers is False
assert dl is None
assert builder.is_empty

@pytest.mark.asyncio
async def test_zero_message_count_not_stale(self):
"""When msg_count is 0 (unknown), staleness check is skipped."""
"""When msg_count is 0 (unknown), gap detection is skipped."""
builder = TranscriptBuilder()
download = TranscriptDownload(
content=_make_transcript_content("user", "assistant"),
restore = TranscriptDownload(
content=_make_transcript_content("user", "assistant").encode("utf-8"),
message_count=0,
mode="sdk",
)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=20,
session_messages=_make_session_messages(*["user"] * 20),
transcript_builder=builder,
)

assert covers is True
assert dl is not None
assert builder.entry_count == 2


@@ -229,7 +253,7 @@ class TestUploadFinalTranscript:
assert call_kwargs["user_id"] == "user-1"
assert call_kwargs["session_id"] == "session-1"
assert call_kwargs["message_count"] == 2
assert "hello" in call_kwargs["content"]
assert b"hello" in call_kwargs["content"]

@pytest.mark.asyncio
async def test_skips_upload_when_builder_empty(self):
@@ -376,17 +400,19 @@ class TestRoundTrip:
@pytest.mark.asyncio
async def test_full_round_trip(self):
prior = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=prior, message_count=2)
restore = TranscriptDownload(
content=prior.encode("utf-8"), message_count=2, mode="sdk"
)

builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, _ = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
session_messages=_make_session_messages("user", "assistant", "user"),
transcript_builder=builder,
)
assert covers is True
@@ -426,11 +452,11 @@ class TestRoundTrip:
upload_mock.assert_awaited_once()
assert upload_mock.await_args is not None
uploaded = upload_mock.await_args.kwargs["content"]
assert "new question" in uploaded
assert "new answer" in uploaded
assert b"new question" in uploaded
assert b"new answer" in uploaded
# Original content preserved in the round trip.
assert "user message 0" in uploaded
assert "assistant message 1" in uploaded
assert b"user message 0" in uploaded
assert b"assistant message 1" in uploaded

@pytest.mark.asyncio
async def test_backfill_append_guard(self):
@@ -461,36 +487,6 @@ class TestRoundTrip:
assert builder.entry_count == initial_count


class TestIsTranscriptStale:
"""``is_transcript_stale`` gates prior-transcript loading."""

def test_none_download_is_not_stale(self):
assert is_transcript_stale(None, session_msg_count=5) is False

def test_zero_message_count_is_not_stale(self):
"""Legacy transcripts without msg_count tracking must remain usable."""
dl = TranscriptDownload(content="", message_count=0)
assert is_transcript_stale(dl, session_msg_count=20) is False

def test_stale_when_covers_less_than_prefix(self):
dl = TranscriptDownload(content="", message_count=2)
# session has 6 messages; transcript must cover at least 5 (6-1).
assert is_transcript_stale(dl, session_msg_count=6) is True

def test_fresh_when_covers_full_prefix(self):
dl = TranscriptDownload(content="", message_count=5)
assert is_transcript_stale(dl, session_msg_count=6) is False

def test_fresh_when_exceeds_prefix(self):
"""Race: transcript ahead of session count is still acceptable."""
dl = TranscriptDownload(content="", message_count=10)
assert is_transcript_stale(dl, session_msg_count=6) is False

def test_boundary_equal_to_prefix_minus_one(self):
dl = TranscriptDownload(content="", message_count=5)
assert is_transcript_stale(dl, session_msg_count=6) is False


class TestShouldUploadTranscript:
"""``should_upload_transcript`` gates the final upload."""

@@ -512,7 +508,7 @@ class TestShouldUploadTranscript:


class TestTranscriptLifecycle:
"""End-to-end: download → validate → build → upload.
"""End-to-end: restore → validate → build → upload.

Simulates the full transcript lifecycle inside
``stream_chat_completion_baseline`` by mocking the storage layer and
@@ -521,27 +517,29 @@ class TestTranscriptLifecycle:

@pytest.mark.asyncio
async def test_full_lifecycle_happy_path(self):
"""Fresh download, append a turn, upload covers the session."""
"""Fresh restore, append a turn, upload covers the session."""
builder = TranscriptBuilder()
prior = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=prior, message_count=2)
restore = TranscriptDownload(
content=prior.encode("utf-8"), message_count=2, mode="sdk"
)

upload_mock = AsyncMock(return_value=None)
with (
patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
),
patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
),
):
# --- 1. Download & load prior transcript ---
covers = await _load_prior_transcript(
# --- 1. Restore & load prior session ---
covers, _ = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
session_messages=_make_session_messages("user", "assistant", "user"),
transcript_builder=builder,
)
assert covers is True
@@ -561,10 +559,7 @@ class TestTranscriptLifecycle:

# --- 3. Gate + upload ---
assert (
should_upload_transcript(
user_id="user-1", upload_safe=covers
)
is True
should_upload_transcript(user_id="user-1", upload_safe=covers) is True
)
await _upload_final_transcript(
user_id="user-1",
@@ -576,20 +571,21 @@ class TestTranscriptLifecycle:
upload_mock.assert_awaited_once()
assert upload_mock.await_args is not None
uploaded = upload_mock.await_args.kwargs["content"]
assert "follow-up question" in uploaded
assert "follow-up answer" in uploaded
assert b"follow-up question" in uploaded
assert b"follow-up answer" in uploaded
# Original prior-turn content preserved.
assert "user message 0" in uploaded
assert "assistant message 1" in uploaded
assert b"user message 0" in uploaded
assert b"assistant message 1" in uploaded

@pytest.mark.asyncio
async def test_lifecycle_stale_download_suppresses_upload(self):
"""Stale download → covers=False → upload must be skipped."""
async def test_lifecycle_stale_download_fills_gap(self):
"""When transcript covers fewer messages, gap is filled rather than rejected."""
builder = TranscriptBuilder()
# session has 10 msgs but stored transcript only covers 2 → stale.
# session has 5 msgs but stored transcript only covers 2 → gap filled.
stale = TranscriptDownload(
content=_make_transcript_content("user", "assistant"),
content=_make_transcript_content("user", "assistant").encode("utf-8"),
message_count=2,
mode="baseline",
)

upload_mock = AsyncMock(return_value=None)
@@ -603,20 +599,18 @@ class TestTranscriptLifecycle:
new=upload_mock,
),
):
covers = await _load_prior_transcript(
covers, _ = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=10,
session_messages=_make_session_messages(
"user", "assistant", "user", "assistant", "user"
),
transcript_builder=builder,
)

assert covers is False
# The caller's gate mirrors the production path.
assert (
should_upload_transcript(user_id="user-1", upload_safe=covers)
is False
)
upload_mock.assert_not_awaited()
assert covers is True
# Gap was filled: 2 from transcript + 2 gap messages
assert builder.entry_count == 4

@pytest.mark.asyncio
async def test_lifecycle_anonymous_user_skips_upload(self):
@@ -629,14 +623,11 @@ class TestTranscriptLifecycle:
stop_reason=STOP_REASON_END_TURN,
)

assert (
should_upload_transcript(user_id=None, upload_safe=True)
is False
)
assert should_upload_transcript(user_id=None, upload_safe=True) is False

@pytest.mark.asyncio
async def test_lifecycle_missing_download_still_uploads_new_content(self):
"""No prior transcript → upload is safe; the turn writes the first snapshot."""
"""No prior session → upload is safe; the turn writes the first snapshot."""
builder = TranscriptBuilder()
upload_mock = AsyncMock(return_value=None)
with (
@@ -649,18 +640,117 @@ class TestTranscriptLifecycle:
new=upload_mock,
),
):
upload_safe = await _load_prior_transcript(
upload_safe, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=1,
session_messages=_make_session_messages("user"),
transcript_builder=builder,
)
# Nothing in GCS → upload is safe so the first baseline turn
# can write the initial snapshot.
# can write the initial transcript snapshot.
assert upload_safe is True
assert dl is None
assert (
should_upload_transcript(
user_id="user-1", upload_safe=upload_safe
)
should_upload_transcript(user_id="user-1", upload_safe=upload_safe)
is True
)


# ---------------------------------------------------------------------------
# _append_gap_to_builder
# ---------------------------------------------------------------------------


class TestAppendGapToBuilder:
"""``_append_gap_to_builder`` converts ChatMessage objects to TranscriptBuilder entries."""

def test_user_message_appended(self):
builder = TranscriptBuilder()
msgs = [ChatMessage(role="user", content="hello")]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 1
assert builder.last_entry_type == "user"

def test_assistant_text_message_appended(self):
builder = TranscriptBuilder()
msgs = [
ChatMessage(role="user", content="q"),
ChatMessage(role="assistant", content="answer"),
]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 2
assert builder.last_entry_type == "assistant"
assert "answer" in builder.to_jsonl()

def test_assistant_with_tool_calls_appended(self):
"""Assistant tool_calls are recorded as tool_use blocks in the transcript."""
builder = TranscriptBuilder()
tool_call = {
"id": "tc-1",
"type": "function",
"function": {"name": "my_tool", "arguments": '{"key":"val"}'},
}
msgs = [ChatMessage(role="assistant", content=None, tool_calls=[tool_call])]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 1
jsonl = builder.to_jsonl()
assert "tool_use" in jsonl
assert "my_tool" in jsonl
assert "tc-1" in jsonl

def test_assistant_invalid_json_args_uses_empty_dict(self):
"""Malformed JSON in tool_call arguments falls back to {}."""
builder = TranscriptBuilder()
tool_call = {
"id": "tc-bad",
"type": "function",
"function": {"name": "bad_tool", "arguments": "not-json"},
}
msgs = [ChatMessage(role="assistant", content=None, tool_calls=[tool_call])]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 1
jsonl = builder.to_jsonl()
assert '"input":{}' in jsonl

def test_assistant_empty_content_and_no_tools_uses_fallback(self):
"""Assistant with no content and no tool_calls gets a fallback empty text block."""
builder = TranscriptBuilder()
msgs = [ChatMessage(role="assistant", content=None)]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 1
jsonl = builder.to_jsonl()
assert "text" in jsonl

def test_tool_role_with_tool_call_id_appended(self):
"""Tool result messages are appended when tool_call_id is set."""
builder = TranscriptBuilder()
# Need a preceding assistant tool_use entry
builder.append_user("use tool")
builder.append_assistant(
content_blocks=[
{"type": "tool_use", "id": "tc-1", "name": "my_tool", "input": {}}
]
)
msgs = [ChatMessage(role="tool", tool_call_id="tc-1", content="result")]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 3
assert "tool_result" in builder.to_jsonl()

def test_tool_role_without_tool_call_id_skipped(self):
"""Tool messages without tool_call_id are silently skipped."""
builder = TranscriptBuilder()
msgs = [ChatMessage(role="tool", tool_call_id=None, content="orphan")]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 0

def test_tool_call_missing_function_key_uses_unknown_name(self):
"""A tool_call dict with no 'function' key uses 'unknown' as the tool name."""
builder = TranscriptBuilder()
# Tool call dict exists but 'function' sub-dict is missing entirely
msgs = [
ChatMessage(role="assistant", content=None, tool_calls=[{"id": "tc-x"}])
]
_append_gap_to_builder(msgs, builder)
assert builder.entry_count == 1
jsonl = builder.to_jsonl()
assert "unknown" in jsonl

@@ -23,7 +23,7 @@ if TYPE_CHECKING:
# Allowed base directory for the Read tool. Public so service.py can use it
# for sweep operations without depending on a private implementation detail.
# Respects CLAUDE_CONFIG_DIR env var, consistent with transcript.py's
# _projects_base() function.
# projects_base() function.
_config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
SDK_PROJECTS_DIR = os.path.realpath(os.path.join(_config_dir, "projects"))


@@ -174,6 +174,7 @@ sandbox so `bash_exec` can access it for further processing.
The exact sandbox path is shown in the `[Sandbox copy available at ...]` note.

### GitHub CLI (`gh`) and git
- To check if the user has their GitHub account already connected, run `gh auth status`. Always check this before asking them to connect it.
- If the user has connected their GitHub account, both `gh` and `git` are
pre-authenticated — use them directly without any manual login step.
`git` HTTPS operations (clone, push, pull) work automatically.

@@ -8,7 +8,7 @@ Cross-mode transcript flow
==========================

Both ``baseline/service.py`` (fast mode) and ``sdk/service.py`` (extended_thinking
mode) read and write the same JSONL transcript store via
mode) read and write the same CLI session store via
``backend.copilot.transcript.upload_transcript`` /
``download_transcript``.

@@ -250,8 +250,9 @@ class TestSdkToFastModeSwitch:

@pytest.mark.asyncio
async def test_scenario_s_baseline_loads_sdk_transcript(self):
"""Scenario S: SDK-written transcript is accepted by baseline's load helper."""
"""Scenario S: SDK-written CLI session is accepted by baseline's load helper."""
from backend.copilot.baseline.service import _load_prior_transcript
from backend.copilot.model import ChatMessage
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
from backend.copilot.transcript_builder import TranscriptBuilder

@@ -267,33 +268,41 @@ class TestSdkToFastModeSwitch:
sdk_transcript = builder_sdk.to_jsonl()

# Baseline session now has those 2 SDK messages + 1 new baseline message.
download = TranscriptDownload(content=sdk_transcript, message_count=2)
restore = TranscriptDownload(
content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk"
)

baseline_builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3, # 2 SDK + 1 new baseline
session_messages=[
ChatMessage(role="user", content="sdk-question"),
ChatMessage(role="assistant", content="sdk-answer"),
ChatMessage(role="user", content="baseline-question"),
],
transcript_builder=baseline_builder,
)

# Transcript is valid and covers the prefix.
# CLI session is valid and covers the prefix.
assert covers is True
assert dl is not None
assert baseline_builder.entry_count == 2

@pytest.mark.asyncio
async def test_scenario_s_stale_sdk_transcript_not_loaded(self):
"""Scenario S (stale): SDK transcript is stale — baseline does not load it.
"""Scenario S (stale): SDK CLI session is stale — baseline does not load it.

If SDK mode produced more turns than the transcript captured (e.g.
upload failed on one turn), the baseline rejects the stale transcript
If SDK mode produced more turns than the session captured (e.g.
upload failed on one turn), the baseline rejects the stale session
to avoid injecting an incomplete history.
"""
from backend.copilot.baseline.service import _load_prior_transcript
from backend.copilot.model import ChatMessage
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
from backend.copilot.transcript_builder import TranscriptBuilder

@@ -306,21 +315,33 @@ class TestSdkToFastModeSwitch:
)
sdk_transcript = builder_sdk.to_jsonl()

# Transcript covers only 2 messages but session has 10 (many SDK turns).
download = TranscriptDownload(content=sdk_transcript, message_count=2)
# Session covers only 2 messages but session has 10 (many SDK turns).
# With watermark=2 and 10 total messages, detect_gap will fill the gap
# by appending messages 2..8 (positions 2 to total-2).
restore = TranscriptDownload(
content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk"
)

# Build a session with 10 alternating user/assistant messages + current user
session_messages = [
ChatMessage(role="user" if i % 2 == 0 else "assistant", content=f"msg-{i}")
for i in range(10)
]

baseline_builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
covers, dl = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=10,
session_messages=session_messages,
transcript_builder=baseline_builder,
)

# Stale transcript must be rejected.
assert covers is False
assert baseline_builder.is_empty
# With gap filling, covers is True and gap messages are appended.
assert covers is True
assert dl is not None
# 2 from transcript + 7 gap messages (positions 2..8, excluding last user turn)
assert baseline_builder.entry_count == 9

@@ -27,6 +27,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest

from backend.copilot.transcript import (
TranscriptDownload,
_flatten_assistant_content,
_flatten_tool_result_content,
_messages_to_transcript,
@@ -999,14 +1000,15 @@ def _make_sdk_patches(
f"{_SVC}.download_transcript",
dict(
new_callable=AsyncMock,
return_value=MagicMock(content=original_transcript, message_count=2),
return_value=TranscriptDownload(
content=original_transcript.encode("utf-8"),
message_count=2,
mode="sdk",
),
),
),
(
f"{_SVC}.restore_cli_session",
dict(new_callable=AsyncMock, return_value=True),
),
(f"{_SVC}.upload_cli_session", dict(new_callable=AsyncMock)),
(f"{_SVC}.strip_for_upload", dict(return_value=original_transcript)),
(f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)),
(f"{_SVC}.validate_transcript", dict(return_value=True)),
(
f"{_SVC}.compact_transcript",
@@ -1037,7 +1039,6 @@ def _make_sdk_patches(
claude_agent_fallback_model=None,
),
),
(f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)),
(f"{_SVC}.get_user_tier", dict(new_callable=AsyncMock, return_value=None)),
]

@@ -1914,14 +1915,14 @@ class TestStreamChatCompletionRetryIntegration:
compacted_transcript=None,
client_side_effect=_client_factory,
)
# Override restore_cli_session to return False (CLI native session unavailable)
# Override download_transcript to return None (CLI native session unavailable)
patches = [
(
(
f"{_SVC}.restore_cli_session",
dict(new_callable=AsyncMock, return_value=False),
f"{_SVC}.download_transcript",
dict(new_callable=AsyncMock, return_value=None),
)
if p[0] == f"{_SVC}.restore_cli_session"
if p[0] == f"{_SVC}.download_transcript"
else p
)
for p in patches
@@ -1944,7 +1945,7 @@ class TestStreamChatCompletionRetryIntegration:
# captured_options holds {"options": ClaudeAgentOptions}, so check
# the attribute directly rather than dict keys.
assert not getattr(captured_options.get("options"), "resume", None), (
f"--resume was set even though restore_cli_session returned False: "
f"--resume was set even though download_transcript returned None: "
f"{captured_options}"
)
assert any(isinstance(e, StreamStart) for e in events)

@@ -365,7 +365,7 @@ def create_security_hooks(
trigger = _sanitize(str(input_data.get("trigger", "auto")), max_len=50)
# Sanitize untrusted input: strip control chars for logging AND
# for the value passed downstream. read_compacted_entries()
# validates against _projects_base() as defence-in-depth, but
# validates against projects_base() as defence-in-depth, but
# sanitizing here prevents log injection and rejects obviously
# malformed paths early.
transcript_path = _sanitize(

@@ -16,6 +16,7 @@ import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from dataclasses import dataclass
from dataclasses import field as dataclass_field
from pathlib import Path
from typing import TYPE_CHECKING, Any, NamedTuple, NotRequired, cast

if TYPE_CHECKING:
@@ -92,12 +93,15 @@ from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
from ..transcript import (
_run_compression,
TranscriptDownload,
cleanup_stale_project_dirs,
cli_session_path,
compact_transcript,
download_transcript,
extract_context_messages,
projects_base,
read_compacted_entries,
restore_cli_session,
upload_cli_session,
strip_for_upload,
upload_transcript,
validate_transcript,
)
@@ -849,6 +853,181 @@ def _make_sdk_cwd(session_id: str) -> str:
return cwd


def _write_cli_session_to_disk(
content: bytes,
sdk_cwd: str,
session_id: str,
log_prefix: str,
) -> bool:
"""Write downloaded CLI session bytes to disk so the CLI can --resume.

Returns True on success, False if the path is invalid or the write fails.
Path-traversal guard: rejects paths outside the CLI projects base.
"""
session_file = cli_session_path(sdk_cwd, session_id)
real_path = os.path.realpath(session_file)
_pbase = projects_base()
if not real_path.startswith(_pbase + os.sep):
logger.warning(
"%s CLI session restore path outside projects base: %s",
log_prefix,
os.path.basename(session_file),
)
return False
try:
os.makedirs(os.path.dirname(real_path), exist_ok=True)
Path(real_path).write_bytes(content)
logger.info(
"%s Wrote CLI session to disk (%dB) for --resume",
log_prefix,
len(content),
)
return True
except OSError as e:
logger.warning(
"%s Failed to write CLI session file %s: %s",
log_prefix,
os.path.basename(session_file),
e.strerror or str(e),
)
return False

def _read_cli_session_from_disk(
sdk_cwd: str,
session_id: str,
log_prefix: str,
) -> bytes | None:
"""Read the CLI session JSONL file from disk after the SDK turn.

Returns the file bytes, or None if the file is missing, outside the
projects base, or unreadable.
Path-traversal guard: rejects paths outside the CLI projects base.
"""
session_file = cli_session_path(sdk_cwd, session_id)
real_path = os.path.realpath(session_file)
_pbase = projects_base()
if not real_path.startswith(_pbase + os.sep):
logger.warning(
"%s CLI session file outside projects base, skipping upload: %s",
log_prefix,
os.path.basename(real_path),
)
return None
try:
raw_bytes = Path(real_path).read_bytes()
except FileNotFoundError:
logger.debug(
"%s CLI session file not found, skipping upload: %s",
log_prefix,
os.path.basename(session_file),
)
return None
except OSError as e:
logger.warning(
"%s Failed to read CLI session file %s: %s",
log_prefix,
os.path.basename(session_file),
e.strerror or str(e),
)
return None

# Strip stale thinking blocks and metadata entries before uploading.
# Thinking blocks from non-last turns can be massive; keeping them causes
# the CLI to auto-compact its session when the context window fills up,
# silently losing conversation history.
try:
raw_text = raw_bytes.decode("utf-8")
stripped_text = strip_for_upload(raw_text)
stripped_bytes = stripped_text.encode("utf-8")
except UnicodeDecodeError:
logger.warning("%s CLI session is not valid UTF-8, uploading raw", log_prefix)
return raw_bytes
except (OSError, ValueError) as e:
# OSError: encode/decode I/O failure; ValueError: malformed JSONL in strip.
# Other unexpected exceptions are not silently swallowed here so they propagate
# to the outer OSError handler and are logged with exc_info.
logger.warning(
"%s Failed to strip CLI session, uploading raw: %s", log_prefix, e
)
return raw_bytes

if len(stripped_bytes) < len(raw_bytes):
# Write back locally so same-pod turns also benefit.
try:
Path(real_path).write_bytes(stripped_bytes)
logger.info(
"%s Stripped CLI session: %dB → %dB",
log_prefix,
len(raw_bytes),
len(stripped_bytes),
)
except OSError as e:
# write_bytes failed — stripped content is still valid for GCS upload even
# though the local write-back failed (same-pod optimization silently skipped).
logger.warning(
"%s Failed to write back stripped CLI session: %s",
log_prefix,
e.strerror or str(e),
)
return stripped_bytes

def _process_cli_restore(
cli_restore: TranscriptDownload,
sdk_cwd: str,
session_id: str,
log_prefix: str,
) -> tuple[str, bool]:
"""Validate and write a restored CLI session to disk.

Decodes bytes → UTF-8, strips progress entries and stale thinking blocks,
validates the result, then writes the stripped content to disk so the CLI
can ``--resume`` from it.

Returns ``(stripped_content, success)`` where ``success=False`` means the
content was invalid or the disk write failed (caller should skip --resume).
"""
try:
raw_bytes = cli_restore.content
raw_str = (
raw_bytes.decode("utf-8") if isinstance(raw_bytes, bytes) else raw_bytes
)
except UnicodeDecodeError:
logger.warning(
"%s CLI session content is not valid UTF-8, skipping", log_prefix
)
return "", False

stripped = strip_for_upload(raw_str)
is_valid = validate_transcript(stripped)
# Use len(raw_str) rather than len(cli_restore.content) so the unit is always
# characters (raw_str is always str at this point regardless of input type).
# lines_stripped = original lines minus remaining lines after stripping.
_original_lines = len(raw_str.strip().split("\n")) if raw_str.strip() else 0
_remaining_lines = len(stripped.strip().split("\n")) if stripped.strip() else 0
logger.info(
"%s Restored CLI session: %dB raw, %d lines stripped, msg_count=%d, valid=%s",
log_prefix,
len(raw_str),
_original_lines - _remaining_lines,
cli_restore.message_count,
is_valid,
)
if not is_valid:
logger.warning(
"%s CLI session content invalid after strip — running without --resume",
log_prefix,
)
return "", False

stripped_bytes = stripped.encode("utf-8")
if not _write_cli_session_to_disk(stripped_bytes, sdk_cwd, session_id, log_prefix):
return "", False

return stripped, True

async def _cleanup_sdk_tool_results(cwd: str) -> None:
"""Remove SDK session artifacts for a specific working directory.

@@ -922,8 +1101,9 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
result.append(block)
else:
logger.warning(
f"[SDK] Unknown content block type: {type(block).__name__}. "
f"This may indicate a new SDK version with additional block types."
"[SDK] Unknown content block type: %s."
" This may indicate a new SDK version with additional block types.",
type(block).__name__,
)
return result

@@ -978,10 +1158,11 @@ async def _compress_messages(

if result.was_compacted:
logger.info(
f"[SDK] Context compacted: {result.original_token_count} -> "
f"{result.token_count} tokens "
f"({result.messages_summarized} summarized, "
f"{result.messages_dropped} dropped)"
"[SDK] Context compacted: %d -> %d tokens (%d summarized, %d dropped)",
result.original_token_count,
result.token_count,
result.messages_summarized,
result.messages_dropped,
)
# Convert compressed dicts back to ChatMessages
return [
@@ -1048,11 +1229,17 @@ def _session_messages_to_transcript(messages: list[ChatMessage]) -> str:
)
if blocks:
builder.append_assistant(blocks)
elif msg.role == "tool" and msg.tool_call_id:
builder.append_tool_result(
tool_use_id=msg.tool_call_id,
content=msg.content or "",
)
elif msg.role == "tool":
if msg.tool_call_id:
builder.append_tool_result(
tool_use_id=msg.tool_call_id,
content=msg.content or "",
)
else:
# Malformed tool message — no tool_call_id to link to an
# assistant tool_use block. Skip to avoid an unmatched
# tool_result entry in the builder (which would confuse --resume).
logger.warning("[SDK] Skipping tool gap message with no tool_call_id")
return builder.to_jsonl()


@@ -1098,6 +1285,7 @@ async def _build_query_message(
transcript_msg_count: int,
session_id: str,
target_tokens: int | None = None,
prior_messages: "list[ChatMessage] | None" = None,
) -> tuple[str, bool]:
"""Build the query message with appropriate context.

@@ -1203,15 +1391,16 @@ async def _build_query_message(
)
return current_message, False

source = prior_messages if prior_messages is not None else prior
logger.warning(
"[SDK] [%s] No --resume for %d-message session — compressing"
" full session history (pod affinity issue or first turn after"
" restore failure); target_tokens=%s",
"[SDK] [%s] No --resume for %d-message session — compressing context "
"(source=%s, target_tokens=%s)",
session_id[:8],
msg_count,
"transcript+gap" if prior_messages is not None else "full-db",
target_tokens,
)
compressed, was_compressed = await _compress_messages(prior, target_tokens)
compressed, was_compressed = await _compress_messages(source, target_tokens)
history_context = _format_conversation_context(compressed)
if history_context:
logger.info(
@@ -1228,7 +1417,7 @@ async def _build_query_message(
"[SDK] [%s] Fallback context empty after compression"
" (%d messages) — sending message without history",
session_id[:8],
len(prior),
len(source),
)

return current_message, False
@@ -2233,6 +2422,163 @@ async def _seed_transcript(
return _seeded, True, len(_prior)

@dataclass
class _RestoreResult:
"""Return value from ``_restore_cli_session_for_turn``."""

transcript_content: str = ""
transcript_covers_prefix: bool = True
use_resume: bool = False
resume_file: str | None = None
transcript_msg_count: int = 0
baseline_download: "TranscriptDownload | None" = None
context_messages: "list[ChatMessage] | None" = None


async def _restore_cli_session_for_turn(
user_id: str | None,
session_id: str,
session: "ChatSession",
sdk_cwd: str,
transcript_builder: "TranscriptBuilder",
log_prefix: str,
) -> _RestoreResult:
"""Download, validate and restore a CLI session for ``--resume`` on this turn.

Performs a single GCS round-trip to fetch the session bytes + message_count
watermark. Falls back to DB-message reconstruction when GCS has no session
(first turn or upload missed).

Returns a ``_RestoreResult`` with all transcript-related state ready for the
caller to merge into its local variables.
"""
result = _RestoreResult()

if not (config.claude_agent_use_resume and user_id and len(session.messages) > 1):
return result

try:
cli_restore = await download_transcript(
user_id, session_id, log_prefix=log_prefix
)
except Exception as restore_err:
logger.warning(
"%s CLI session restore failed, continuing without --resume: %s",
log_prefix,
restore_err,
)
cli_restore = None

# Only attempt --resume for SDK-written transcripts.
# Baseline-written transcripts use TranscriptBuilder format (synthetic IDs,
# stripped fields) that may not be valid for --resume.
if cli_restore is not None and cli_restore.mode != "sdk":
logger.info(
"%s Transcript written by mode=%r — skipping --resume, "
"will use transcript content + gap for context",
log_prefix,
cli_restore.mode,
)
result.baseline_download = cli_restore  # keep for extract_context_messages
cli_restore = None

# Validate, strip, and write to disk — delegate to helper to reduce
# function complexity. Writing an invalid/corrupt file to disk then
# falling back to "no --resume" would cause the CLI to fail with
# "Session ID already in use" because the file exists at the expected
# session path, so we validate BEFORE any disk write.
stripped = ""
if cli_restore is not None and sdk_cwd:
stripped, ok = _process_cli_restore(
cli_restore, sdk_cwd, session_id, log_prefix
)
if not ok:
result.transcript_covers_prefix = False
cli_restore = None

if cli_restore is None and sdk_cwd:
# Validation failed or GCS returned no session. Delete any
# existing local session file so the CLI doesn't reject the
# session_id with "Session ID already in use". T1 may have
# left a valid file at this path; we clear it so the fallback
# path (session_id= without --resume) can create a new session.
_stale_path = os.path.realpath(cli_session_path(sdk_cwd, session_id))
if Path(_stale_path).exists() and _stale_path.startswith(
projects_base() + os.sep
):
try:
Path(_stale_path).unlink()
logger.debug(
"%s Removed stale local CLI session file for clean fallback",
log_prefix,
)
except OSError as _unlink_err:
logger.debug(
"%s Failed to remove stale local session file: %s",
log_prefix,
_unlink_err,
)

if cli_restore is not None:
result.transcript_content = stripped
transcript_builder.load_previous(stripped, log_prefix=log_prefix)
result.use_resume = True
result.resume_file = session_id
result.transcript_msg_count = cli_restore.message_count
return result

# No valid --resume source (mode="baseline" or no GCS file).
# Build context from transcript content + gap, falling back to full DB.
# extract_context_messages handles both: non-None baseline_download uses
# the compacted transcript + gap; None falls back to all prior DB messages.
context_msgs = extract_context_messages(result.baseline_download, session.messages)
result.context_messages = context_msgs
result.transcript_msg_count = (
result.baseline_download.message_count
if result.baseline_download is not None
and result.baseline_download.message_count > 0
else len(session.messages) - 1
)
result.transcript_covers_prefix = True
logger.info(
"%s Context built from %s: %d messages (transcript watermark=%d, "
"will inject as <conversation_history>)",
log_prefix,
(
"baseline transcript + gap"
if result.baseline_download is not None
else "DB fallback"
),
len(context_msgs),
result.transcript_msg_count,
)

# Load baseline transcript content into builder so the upload path has accurate state.
# Also sets result.transcript_content so the _seed_transcript guard in the caller
# (``not transcript_content``) does not overwrite this builder state with a DB
# reconstruction — which would duplicate entries since load_previous appends.
if result.baseline_download is not None:
try:
raw_for_builder = result.baseline_download.content
if isinstance(raw_for_builder, bytes):
raw_for_builder = raw_for_builder.decode("utf-8")
stripped = strip_for_upload(raw_for_builder)
if validate_transcript(stripped):
transcript_builder.load_previous(stripped, log_prefix=log_prefix)
result.transcript_content = stripped
except (UnicodeDecodeError, ValueError, OSError) as _load_err:
# UnicodeDecodeError: non-UTF-8 content; ValueError: malformed JSONL in
# strip_for_upload; OSError: encode/decode I/O failure. Unexpected
# exceptions propagate so programming errors are not silently masked.
||||
logger.debug(
|
||||
"%s Could not load baseline transcript into builder: %s",
|
||||
log_prefix,
|
||||
_load_err,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
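The result object is meant to be merged straight into the caller's per-turn locals. Below is a minimal sketch of that consumption, based on the call site added further down in this diff; the helper name and the branch on use_resume are illustrative only, not a verbatim excerpt (the "resume" option key is the one the retry tests later in this diff assert on).

    async def consume_restore_result(
        user_id, session_id, session, sdk_cwd, transcript_builder, log_prefix,
        sdk_options_kwargs: dict,
    ):
        # Hypothetical wrapper showing how the caller feeds the result into its state.
        restore = await _restore_cli_session_for_turn(
            user_id, session_id, session, sdk_cwd, transcript_builder, log_prefix
        )
        if restore.use_resume:
            # CLI native session restored to disk; resume by session UUID.
            sdk_options_kwargs["resume"] = restore.resume_file
        else:
            # No --resume source this turn: give the CLI a predictable session
            # path, and inject restore.context_messages later via
            # prior_messages= when building the query message.
            sdk_options_kwargs["session_id"] = session_id
        return restore.transcript_msg_count, restore.context_messages
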
async def stream_chat_completion_sdk(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
@@ -2427,28 +2773,9 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
return sandbox
|
||||
|
||||
async def _fetch_transcript():
|
||||
"""Download transcript for --resume if applicable."""
|
||||
if not (
|
||||
config.claude_agent_use_resume and user_id and len(session.messages) > 1
|
||||
):
|
||||
return None
|
||||
try:
|
||||
return await download_transcript(
|
||||
user_id, session_id, log_prefix=log_prefix
|
||||
)
|
||||
except Exception as transcript_err:
|
||||
logger.warning(
|
||||
"%s Transcript download failed, continuing without --resume: %s",
|
||||
log_prefix,
|
||||
transcript_err,
|
||||
)
|
||||
return None
|
||||
|
||||
e2b_sandbox, (base_system_prompt, understanding), dl = await asyncio.gather(
|
||||
e2b_sandbox, (base_system_prompt, understanding) = await asyncio.gather(
|
||||
_setup_e2b(),
|
||||
_build_system_prompt(user_id if not has_history else None),
|
||||
_fetch_transcript(),
|
||||
)
|
||||
|
||||
use_e2b = e2b_sandbox is not None
|
||||
@@ -2473,95 +2800,17 @@ async def stream_chat_completion_sdk(
|
||||
|
||||
warm_ctx = await fetch_warm_context(user_id, message or "") or ""
|
||||
|
||||
# Process transcript download result and restore CLI native session.
|
||||
# The CLI native session file (uploaded after each turn) is the
|
||||
# source of truth for --resume. Our custom JSONL (TranscriptEntry)
|
||||
# is loaded into the builder for future upload_transcript calls.
|
||||
transcript_msg_count = 0
|
||||
if dl:
|
||||
is_valid = validate_transcript(dl.content)
|
||||
dl_lines = dl.content.strip().split("\n") if dl.content else []
|
||||
logger.info(
|
||||
"%s Downloaded transcript: %dB, %d lines, msg_count=%d, valid=%s",
|
||||
log_prefix,
|
||||
len(dl.content),
|
||||
len(dl_lines),
|
||||
dl.message_count,
|
||||
is_valid,
|
||||
)
|
||||
if is_valid:
|
||||
# Load previous FULL context into builder for state tracking.
|
||||
transcript_content = dl.content
|
||||
transcript_builder.load_previous(dl.content, log_prefix=log_prefix)
|
||||
# Restore CLI's native session file so --resume session_id works.
|
||||
# Falls back gracefully if not available (first turn or upload missed).
|
||||
# user_id is guaranteed non-None here: _fetch_transcript only sets dl
|
||||
# when `config.claude_agent_use_resume and user_id` is truthy.
|
||||
cli_restored = user_id is not None and await restore_cli_session(
|
||||
user_id, session_id, sdk_cwd, log_prefix=log_prefix
|
||||
)
|
||||
if cli_restored:
|
||||
use_resume = True
|
||||
resume_file = session_id # CLI --resume expects UUID, not file path
|
||||
transcript_msg_count = dl.message_count
|
||||
logger.info(
|
||||
"%s Using --resume %s (%dB transcript, msg_count=%d)",
|
||||
log_prefix,
|
||||
session_id[:8],
|
||||
len(dl.content),
|
||||
transcript_msg_count,
|
||||
)
|
||||
else:
|
||||
# Builder loaded but CLI native session not available.
|
||||
# --resume will not be used this turn; upload after turn
|
||||
# will seed the native session for the next turn.
|
||||
#
|
||||
# Still record transcript_msg_count so _build_query_message
|
||||
# can use the transcript-aware gap path (inject only new
|
||||
# messages since the transcript end) instead of compressing
|
||||
# the full DB history. This avoids prompt-too-long on
|
||||
# large sessions where the CLI session is temporarily
|
||||
# unavailable (e.g. mixed-version rolling deployment).
|
||||
transcript_msg_count = dl.message_count
|
||||
logger.info(
|
||||
"%s CLI session not restored — running without"
|
||||
" --resume this turn (transcript_msg_count=%d for"
|
||||
" gap-aware fallback)",
|
||||
log_prefix,
|
||||
transcript_msg_count,
|
||||
)
|
||||
else:
|
||||
logger.warning("%s Transcript downloaded but invalid", log_prefix)
|
||||
transcript_covers_prefix = False
|
||||
elif config.claude_agent_use_resume and user_id and len(session.messages) > 1:
|
||||
# No transcript in storage — reconstruct from DB messages as a
|
||||
# last-resort fallback (e.g., first turn after a crash or transition).
|
||||
# This path loses tool call IDs and structural fidelity but prevents
|
||||
# a completely context-free response for established sessions.
|
||||
prior = session.messages[:-1]
|
||||
reconstructed = _session_messages_to_transcript(prior)
|
||||
if reconstructed:
|
||||
# Populate builder only; no --resume since there is no CLI
|
||||
# native session to restore. The transcript builder state is
|
||||
# still useful for the upload that seeds future native sessions.
|
||||
transcript_content = reconstructed
|
||||
transcript_builder.load_previous(reconstructed, log_prefix=log_prefix)
|
||||
transcript_msg_count = len(prior)
|
||||
transcript_covers_prefix = True
|
||||
logger.info(
|
||||
"%s Reconstructed transcript from %d session messages "
|
||||
"(no CLI native session — running without --resume this turn)",
|
||||
log_prefix,
|
||||
len(prior),
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"%s No transcript available and reconstruction produced empty"
|
||||
" output (%d messages in session)",
|
||||
log_prefix,
|
||||
len(session.messages),
|
||||
)
|
||||
transcript_covers_prefix = False
|
||||
# Restore CLI session — single GCS round-trip covers both --resume and builder state.
# message_count watermark lives in the companion .meta.json alongside the session file.
_restore = await _restore_cli_session_for_turn(
user_id, session_id, session, sdk_cwd, transcript_builder, log_prefix
)
transcript_content = _restore.transcript_content
transcript_covers_prefix = _restore.transcript_covers_prefix
use_resume = _restore.use_resume
resume_file = _restore.resume_file
transcript_msg_count = _restore.transcript_msg_count
restore_context_messages = _restore.context_messages

yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
|
||||
@@ -2680,14 +2929,14 @@ async def stream_chat_completion_sdk(
|
||||
else:
|
||||
# Set session_id whenever NOT resuming so the CLI writes the
|
||||
# native session file to a predictable path for
|
||||
# upload_cli_session() after the turn. This covers:
|
||||
# upload_transcript() after the turn. This covers:
|
||||
# • T1 fresh: no prior history, first SDK turn.
|
||||
# • Mode-switch T1: has_history=True (prior baseline turns in
|
||||
# DB) but no CLI session file was ever uploaded — the CLI has
|
||||
# never been invoked with this session_id before.
|
||||
# • T2+ without --resume (restore failed): no session file was
|
||||
# restored to local storage (restore_cli_session returned
|
||||
# False), so no conflict with an existing file.
|
||||
# restored to local storage (download_transcript returned
|
||||
# None), so no conflict with an existing file.
|
||||
# When --resume is active the session_id is already implied by
|
||||
# the resume file; passing it again would be rejected by the CLI.
|
||||
sdk_options_kwargs["session_id"] = session_id
|
||||
@@ -2780,6 +3029,7 @@ async def stream_chat_completion_sdk(
|
||||
use_resume,
|
||||
transcript_msg_count,
|
||||
session_id,
|
||||
prior_messages=restore_context_messages,
|
||||
)
|
||||
# If files are attached, prepare them: images become vision
|
||||
# content blocks in the user message, other files go to sdk_cwd.
|
||||
@@ -2909,7 +3159,7 @@ async def stream_chat_completion_sdk(
|
||||
elif "session_id" in sdk_options_kwargs:
|
||||
# Initial invocation used session_id (T1 or mode-switch
|
||||
# T1): keep it so the CLI writes the session file to the
|
||||
# predictable path for upload_cli_session(). Storage is
|
||||
# predictable path for upload_transcript(). Storage is
|
||||
# ephemeral per invocation, so no "Session ID already in
|
||||
# use" conflict occurs — no prior file was restored.
|
||||
sdk_options_kwargs_retry.pop("resume", None)
|
||||
@@ -2932,6 +3182,10 @@ async def stream_chat_completion_sdk(
|
||||
system_prompt, cross_user_cache=_cross_user_retry
|
||||
)
|
||||
state.options = ClaudeAgentOptions(**sdk_options_kwargs_retry) # type: ignore[arg-type] # dynamic kwargs
|
||||
# Retry intentionally omits prior_messages (transcript+gap context) and
|
||||
# falls back to full session.messages[:-1] from DB — the authoritative
|
||||
# source. transcript+gap is an optimisation for the first attempt only;
|
||||
# on retry the extra overhead of full-DB context is acceptable.
|
||||
state.query_message, state.was_compacted = await _build_query_message(
|
||||
current_message,
|
||||
session,
|
||||
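One way to read the retry policy spelled out above is as a small pure rule: transcript+gap context on the first attempt, full-DB history on any retry. The helper below is purely illustrative (it does not exist in the codebase) and just restates that rule.

    from typing import Optional, Sequence

    def prior_messages_for_attempt(
        attempt: int,
        restore_context_messages: Optional[Sequence],
    ) -> Optional[Sequence]:
        """Context to pass as prior_messages when building the query message.

        Attempt 0 prefers the cheaper transcript+gap context when the restore
        step produced one; any retry returns None so the builder falls back to
        the authoritative session.messages[:-1] from the DB.
        """
        if attempt == 0:
            return restore_context_messages
        return None
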
@@ -3367,86 +3621,23 @@ async def stream_chat_completion_sdk(
|
||||
_background_tasks.add(_ingest_task)
|
||||
_ingest_task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
# --- Upload transcript for next-turn --resume ---
|
||||
# TranscriptBuilder is the single source of truth. It mirrors the
|
||||
# CLI's active context: on compaction, replace_entries() syncs it
|
||||
# with the compacted session file. No CLI file read needed here.
|
||||
if skip_transcript_upload:
|
||||
logger.warning(
|
||||
"%s Skipping transcript upload — transcript was dropped "
|
||||
"during prompt-too-long recovery",
|
||||
log_prefix,
|
||||
)
|
||||
elif (
|
||||
config.claude_agent_use_resume
|
||||
and user_id
|
||||
and session is not None
|
||||
and state is not None
|
||||
):
|
||||
try:
|
||||
transcript_upload_content = state.transcript_builder.to_jsonl()
|
||||
entry_count = state.transcript_builder.entry_count
|
||||
|
||||
if not transcript_upload_content:
|
||||
logger.warning(
|
||||
"%s No transcript to upload (builder empty)", log_prefix
|
||||
)
|
||||
elif not validate_transcript(transcript_upload_content):
|
||||
logger.warning(
|
||||
"%s Transcript invalid, skipping upload (entries=%d)",
|
||||
log_prefix,
|
||||
entry_count,
|
||||
)
|
||||
elif not transcript_covers_prefix:
|
||||
logger.warning(
|
||||
"%s Skipping transcript upload — builder does not "
|
||||
"cover full session prefix (entries=%d, session=%d)",
|
||||
log_prefix,
|
||||
entry_count,
|
||||
len(session.messages),
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"%s Uploading transcript (entries=%d, bytes=%d)",
|
||||
log_prefix,
|
||||
entry_count,
|
||||
len(transcript_upload_content),
|
||||
)
|
||||
await asyncio.shield(
|
||||
upload_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
content=transcript_upload_content,
|
||||
message_count=len(session.messages),
|
||||
log_prefix=log_prefix,
|
||||
)
|
||||
)
|
||||
except Exception as upload_err:
|
||||
logger.error(
|
||||
"%s Transcript upload failed in finally: %s",
|
||||
log_prefix,
|
||||
upload_err,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# --- Upload CLI native session file for cross-pod --resume ---
|
||||
# The CLI writes its native session JSONL after each turn completes.
|
||||
# Uploading it here enables --resume on any pod (no pod affinity needed).
|
||||
# Runs after upload_transcript so both are available for the next turn.
|
||||
# asyncio.shield: same pattern as upload_transcript above — if the
|
||||
# outer finally-block coroutine is cancelled while awaiting shield,
|
||||
# the CancelledError propagates (BaseException, not caught by
|
||||
# `except Exception`) letting the caller handle cancellation, while
|
||||
# the shielded inner coroutine continues running to completion so the
|
||||
# upload is not lost. This is intentional and matches the pattern
|
||||
# used for upload_transcript immediately above.
|
||||
# The companion .meta.json carries the message_count watermark and mode
|
||||
# so the next turn can restore both --resume context and gap-fill state
|
||||
# in a single GCS round-trip via download_transcript().
|
||||
# asyncio.shield: if the outer finally-block coroutine is cancelled
|
||||
# while awaiting shield, the CancelledError propagates (BaseException,
|
||||
# not caught by `except Exception`) letting the caller handle
|
||||
# cancellation, while the shielded inner coroutine continues running
|
||||
# to completion so the upload is not lost.
|
||||
#
|
||||
# NOTE: upload is attempted regardless of state.use_resume — even when
|
||||
# this turn ran without --resume (restore failed or first T2+ on a new
|
||||
# pod), the T1 session file at the expected path may still be present
|
||||
# and should be re-uploaded so the next turn can resume from it.
|
||||
# upload_cli_session silently skips when the file is absent, so this is
|
||||
# always safe.
|
||||
# _read_cli_session_from_disk returns None when the file is absent, so
|
||||
# this is always safe.
|
||||
#
|
||||
# Intentionally NOT gated on skip_transcript_upload: that flag is set
|
||||
# when our custom JSONL transcript is dropped (transcript_lost=True on
|
||||
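The exact schema of the companion .meta.json is not spelled out in this diff; judging from what upload_transcript() receives and what TranscriptDownload exposes, a plausible and purely illustrative payload and reader look like this (field names other than message_count and mode are assumptions):

    import json

    def parse_transcript_meta(raw: bytes) -> tuple[int, str]:
        # Example payload (assumed shape): {"message_count": 14, "mode": "sdk"}
        # Only message_count (gap-fill watermark) and mode (sdk vs baseline)
        # are relied on by the restore path described above.
        meta = json.loads(raw.decode("utf-8"))
        return int(meta.get("message_count", 0)), str(meta.get("mode", "sdk"))
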
@@ -3472,14 +3663,36 @@ async def stream_chat_completion_sdk(
|
||||
skip_transcript_upload,
|
||||
)
|
||||
try:
|
||||
await asyncio.shield(
|
||||
upload_cli_session(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
sdk_cwd=sdk_cwd,
|
||||
log_prefix=log_prefix,
|
||||
)
|
||||
# Read the CLI's native session file from disk (written by the CLI
|
||||
# after the turn), then upload the bytes to GCS.
|
||||
_cli_content = _read_cli_session_from_disk(
|
||||
sdk_cwd, session_id, log_prefix
|
||||
)
|
||||
if _cli_content:
|
||||
# Watermark = number of DB messages this transcript covers.
|
||||
# len(session.messages) is accurate: the CLI session file
|
||||
# was just written after the turn completed, so it covers
|
||||
# all messages through this turn. Any gap from a prior
|
||||
# missed upload was already detected by detect_gap and
|
||||
# injected as context, so the model has the full history.
|
||||
#
|
||||
# Previously this used _final_tmsg_count + 2, which
|
||||
# under-counted for tool-use turns (delta = 2 + 2*N_tool_calls),
|
||||
# causing persistent spurious gap-fills on every subsequent turn.
|
||||
# That concern was addressed by the inflated-watermark fix
|
||||
# (using the GCS watermark as the anchor for gap detection),
|
||||
# which makes len(session.messages) safe to use here.
|
||||
_jsonl_covered = len(session.messages)
|
||||
await asyncio.shield(
|
||||
upload_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
content=_cli_content,
|
||||
message_count=_jsonl_covered,
|
||||
mode="sdk",
|
||||
log_prefix=log_prefix,
|
||||
)
|
||||
)
|
||||
except Exception as cli_upload_err:
|
||||
logger.warning(
|
||||
"%s CLI session upload failed in finally: %s",
|
||||
|
||||
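A worked example of the watermark arithmetic referenced in the hunk above (illustrative numbers; the per-turn delta formula is quoted from that comment, not derived here):

    def turn_delta(n_tool_calls: int) -> int:
        # Per the comment above: a tool-use turn adds 2 + 2*N_tool_calls DB messages.
        return 2 + 2 * n_tool_calls

    covered_before = 12                          # DB messages covered before the turn
    db_after = covered_before + turn_delta(3)    # 12 + 8 = 20 messages after the turn

    old_watermark = covered_before + 2           # 14: the old "+2" rule, six short
    new_watermark = db_after                     # 20: len(session.messages) after the turn

    # The next turn's staleness check is `watermark < db_count - 1`:
    assert old_watermark < db_after - 1          # 14 < 19, so gap-fill fires every turn
    assert not (new_watermark < db_after - 1)    # 20 < 19 is False, so no false gap
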
@@ -22,6 +22,7 @@ from .service import (
|
||||
_iter_sdk_messages,
|
||||
_normalize_model_name,
|
||||
_reduce_context,
|
||||
_restore_cli_session_for_turn,
|
||||
_TokenUsage,
|
||||
)
|
||||
|
||||
@@ -615,3 +616,340 @@ class TestSdkSessionIdSelection:
|
||||
)
|
||||
assert retry.get("resume") == self.SESSION_ID
|
||||
assert "session_id" not in retry
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _restore_cli_session_for_turn — mode check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRestoreCliSessionModeCheck:
|
||||
"""SDK skips --resume when the transcript was written by the baseline mode."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_baseline_mode_transcript_skips_gcs_content(self, tmp_path):
|
||||
"""A transcript with mode='baseline' must not be used as the --resume source.
|
||||
|
||||
The mode check discards the GCS baseline content and falls back to DB
|
||||
reconstruction from session.messages instead.
|
||||
"""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.transcript import TranscriptDownload
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
|
||||
session = ChatSession(
|
||||
session_id="test-session",
|
||||
user_id="user-1",
|
||||
messages=[
|
||||
ChatMessage(role="user", content="hello-unique-marker"),
|
||||
ChatMessage(role="assistant", content="world-unique-marker"),
|
||||
ChatMessage(role="user", content="follow up"),
|
||||
],
|
||||
title="test",
|
||||
usage=[],
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
builder = TranscriptBuilder()
|
||||
# Baseline content with a sentinel that must NOT appear in the final transcript
|
||||
baseline_restore = TranscriptDownload(
|
||||
content=b'{"type":"user","uuid":"bad-uuid","message":{"role":"user","content":"BASELINE_SENTINEL"}}\n',
|
||||
message_count=1,
|
||||
mode="baseline",
|
||||
)
|
||||
|
||||
import backend.copilot.sdk.service as _svc_mod
|
||||
|
||||
download_mock = AsyncMock(return_value=baseline_restore)
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.download_transcript",
|
||||
new=download_mock,
|
||||
),
|
||||
patch.object(_svc_mod.config, "claude_agent_use_resume", True),
|
||||
):
|
||||
result = await _restore_cli_session_for_turn(
|
||||
user_id="user-1",
|
||||
session_id="test-session",
|
||||
session=session,
|
||||
sdk_cwd=str(tmp_path),
|
||||
transcript_builder=builder,
|
||||
log_prefix="[Test]",
|
||||
)
|
||||
|
||||
# download_transcript was called (attempted GCS restore)
|
||||
download_mock.assert_awaited_once()
|
||||
# use_resume must be False — baseline transcripts cannot be used with --resume
|
||||
assert result.use_resume is False
|
||||
# context_messages must be populated — new behaviour uses transcript content + gap
|
||||
# instead of full DB reconstruction.
|
||||
assert result.context_messages is not None
|
||||
# The baseline transcript has 1 user message (BASELINE_SENTINEL).
|
||||
# Watermark=1 but position 0 is 'user', not 'assistant', so detect_gap returns [].
|
||||
# Result: 1 message from transcript, no gap.
|
||||
assert len(result.context_messages) == 1
|
||||
assert "BASELINE_SENTINEL" in (result.context_messages[0].content or "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sdk_mode_transcript_allows_resume(self, tmp_path):
|
||||
"""A valid SDK-written transcript is accepted for --resume."""
|
||||
import json as stdlib_json
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
|
||||
lines = [
|
||||
stdlib_json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "uid-0",
|
||||
"parentUuid": "",
|
||||
"message": {"role": "user", "content": "hi"},
|
||||
}
|
||||
),
|
||||
stdlib_json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "uid-1",
|
||||
"parentUuid": "uid-0",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_1",
|
||||
"model": "test",
|
||||
"type": "message",
|
||||
"stop_reason": STOP_REASON_END_TURN,
|
||||
"content": [{"type": "text", "text": "hello"}],
|
||||
},
|
||||
}
|
||||
),
|
||||
]
|
||||
content = ("\n".join(lines) + "\n").encode("utf-8")
|
||||
|
||||
session = ChatSession(
|
||||
session_id="test-session",
|
||||
user_id="user-1",
|
||||
messages=[
|
||||
ChatMessage(role="user", content="hi"),
|
||||
ChatMessage(role="assistant", content="hello"),
|
||||
ChatMessage(role="user", content="follow up"),
|
||||
],
|
||||
title="test",
|
||||
usage=[],
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
builder = TranscriptBuilder()
|
||||
sdk_restore = TranscriptDownload(
|
||||
content=content,
|
||||
message_count=2,
|
||||
mode="sdk",
|
||||
)
|
||||
|
||||
import backend.copilot.sdk.service as _svc_mod
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.download_transcript",
|
||||
new=AsyncMock(return_value=sdk_restore),
|
||||
),
|
||||
patch.object(_svc_mod.config, "claude_agent_use_resume", True),
|
||||
):
|
||||
result = await _restore_cli_session_for_turn(
|
||||
user_id="user-1",
|
||||
session_id="test-session",
|
||||
session=session,
|
||||
sdk_cwd=str(tmp_path),
|
||||
transcript_builder=builder,
|
||||
log_prefix="[Test]",
|
||||
)
|
||||
|
||||
assert result.use_resume is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_baseline_mode_context_messages_from_transcript_content(
|
||||
self, tmp_path
|
||||
):
|
||||
"""mode='baseline' → context_messages populated from transcript content + gap.
|
||||
|
||||
When a baseline-mode transcript exists, extract_context_messages converts
|
||||
the JSONL content to ChatMessage objects and returns them in context_messages.
|
||||
use_resume must remain False.
|
||||
"""
|
||||
import json as stdlib_json
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
|
||||
# Build a minimal valid JSONL transcript with 2 messages
|
||||
lines = [
|
||||
stdlib_json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "uid-0",
|
||||
"parentUuid": "",
|
||||
"message": {"role": "user", "content": "TRANSCRIPT_USER"},
|
||||
}
|
||||
),
|
||||
stdlib_json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "uid-1",
|
||||
"parentUuid": "uid-0",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_1",
|
||||
"model": "test",
|
||||
"type": "message",
|
||||
"stop_reason": STOP_REASON_END_TURN,
|
||||
"content": [{"type": "text", "text": "TRANSCRIPT_ASSISTANT"}],
|
||||
},
|
||||
}
|
||||
),
|
||||
]
|
||||
content = ("\n".join(lines) + "\n").encode("utf-8")
|
||||
|
||||
session = ChatSession(
|
||||
session_id="test-session",
|
||||
user_id="user-1",
|
||||
messages=[
|
||||
ChatMessage(role="user", content="DB_USER"),
|
||||
ChatMessage(role="assistant", content="DB_ASSISTANT"),
|
||||
ChatMessage(role="user", content="current turn"),
|
||||
],
|
||||
title="test",
|
||||
usage=[],
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
builder = TranscriptBuilder()
|
||||
baseline_restore = TranscriptDownload(
|
||||
content=content,
|
||||
message_count=2,
|
||||
mode="baseline",
|
||||
)
|
||||
|
||||
import backend.copilot.sdk.service as _svc_mod
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.download_transcript",
|
||||
new=AsyncMock(return_value=baseline_restore),
|
||||
),
|
||||
patch.object(_svc_mod.config, "claude_agent_use_resume", True),
|
||||
):
|
||||
result = await _restore_cli_session_for_turn(
|
||||
user_id="user-1",
|
||||
session_id="test-session",
|
||||
session=session,
|
||||
sdk_cwd=str(tmp_path),
|
||||
transcript_builder=builder,
|
||||
log_prefix="[Test]",
|
||||
)
|
||||
|
||||
assert result.use_resume is False
|
||||
assert result.context_messages is not None
|
||||
# Transcript content has 2 messages, no gap (watermark=2, session prior=2)
|
||||
assert len(result.context_messages) == 2
|
||||
assert result.context_messages[0].role == "user"
|
||||
assert result.context_messages[1].role == "assistant"
|
||||
assert "TRANSCRIPT_ASSISTANT" in (result.context_messages[1].content or "")
|
||||
# transcript_content must be non-empty so the _seed_transcript guard in
|
||||
# stream_chat_completion_sdk skips DB reconstruction (which would duplicate
|
||||
# builder entries since load_previous appends).
|
||||
assert result.transcript_content != ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_baseline_mode_gap_present_context_includes_gap(self, tmp_path):
|
||||
"""mode='baseline' + gap → context_messages includes transcript msgs and gap."""
|
||||
import json as stdlib_json
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
|
||||
# Transcript covers only 2 messages; session has 4 prior + current turn
|
||||
lines = [
|
||||
stdlib_json.dumps(
|
||||
{
|
||||
"type": "user",
|
||||
"uuid": "uid-0",
|
||||
"parentUuid": "",
|
||||
"message": {"role": "user", "content": "TRANSCRIPT_USER_0"},
|
||||
}
|
||||
),
|
||||
stdlib_json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "uid-1",
|
||||
"parentUuid": "uid-0",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_1",
|
||||
"model": "test",
|
||||
"type": "message",
|
||||
"stop_reason": STOP_REASON_END_TURN,
|
||||
"content": [{"type": "text", "text": "TRANSCRIPT_ASSISTANT_1"}],
|
||||
},
|
||||
}
|
||||
),
|
||||
]
|
||||
content = ("\n".join(lines) + "\n").encode("utf-8")
|
||||
|
||||
session = ChatSession(
|
||||
session_id="test-session",
|
||||
user_id="user-1",
|
||||
messages=[
|
||||
ChatMessage(role="user", content="DB_USER_0"),
|
||||
ChatMessage(role="assistant", content="DB_ASSISTANT_1"),
|
||||
ChatMessage(role="user", content="GAP_USER_2"),
|
||||
ChatMessage(role="assistant", content="GAP_ASSISTANT_3"),
|
||||
ChatMessage(role="user", content="current turn"),
|
||||
],
|
||||
title="test",
|
||||
usage=[],
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
builder = TranscriptBuilder()
|
||||
baseline_restore = TranscriptDownload(
|
||||
content=content,
|
||||
message_count=2, # watermark=2; session has 4 prior → gap of 2
|
||||
mode="baseline",
|
||||
)
|
||||
|
||||
import backend.copilot.sdk.service as _svc_mod
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.download_transcript",
|
||||
new=AsyncMock(return_value=baseline_restore),
|
||||
),
|
||||
patch.object(_svc_mod.config, "claude_agent_use_resume", True),
|
||||
):
|
||||
result = await _restore_cli_session_for_turn(
|
||||
user_id="user-1",
|
||||
session_id="test-session",
|
||||
session=session,
|
||||
sdk_cwd=str(tmp_path),
|
||||
transcript_builder=builder,
|
||||
log_prefix="[Test]",
|
||||
)
|
||||
|
||||
assert result.use_resume is False
|
||||
assert result.context_messages is not None
|
||||
# 2 from transcript + 2 gap messages = 4 total
|
||||
assert len(result.context_messages) == 4
|
||||
roles = [m.role for m in result.context_messages]
|
||||
assert roles == ["user", "assistant", "user", "assistant"]
|
||||
# Gap messages come from DB (ChatMessage objects)
|
||||
gap_user = result.context_messages[2]
|
||||
gap_asst = result.context_messages[3]
|
||||
assert gap_user.content == "GAP_USER_2"
|
||||
assert gap_asst.content == "GAP_ASSISTANT_3"
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
"""Unit tests for the watermark-fix logic in stream_chat_completion_sdk.
|
||||
|
||||
The fix is at the upload step: when use_resume=True and transcript_msg_count>0
|
||||
we set the JSONL coverage watermark to transcript_msg_count + 2 (the pair just
|
||||
recorded) instead of len(session.messages). This prevents the "inflated
|
||||
watermark" bug where a stale JSONL in GCS could hide missing context from
|
||||
future gap-fill checks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def _compute_jsonl_covered(
|
||||
use_resume: bool,
|
||||
transcript_msg_count: int,
|
||||
session_msg_count: int,
|
||||
) -> int:
|
||||
"""Mirror the watermark computation from ``stream_chat_completion_sdk``.
|
||||
|
||||
Extracted here so we can unit-test it independently without invoking the
|
||||
full streaming stack.
|
||||
"""
|
||||
if use_resume and transcript_msg_count > 0:
|
||||
return transcript_msg_count + 2
|
||||
return session_msg_count
|
||||
|
||||
|
||||
class TestWatermarkFix:
|
||||
"""Watermark computation logic — mirrors the finally-block in SDK service."""
|
||||
|
||||
def test_inflated_watermark_triggers_gap_fill(self):
|
||||
"""Stale JSONL (T12) with high watermark (46) → after fix, watermark=14.
|
||||
|
||||
Before fix: watermark=46 → next turn's gap check (transcript_msg_count < db-1)
|
||||
never fires because 46 >= 47-1=46, so context loss is silent.
|
||||
After fix: watermark = 12 + 2 = 14 → gap check fires (14 < 46) and
|
||||
the model receives the missing turns.
|
||||
"""
|
||||
# Simulate: use_resume=True, transcript covered T12 (12 msgs), DB now has 47
|
||||
use_resume = True
|
||||
transcript_msg_count = 12
|
||||
session_msg_count = 47 # DB count (what old code used to set watermark)
|
||||
|
||||
watermark = _compute_jsonl_covered(
|
||||
use_resume, transcript_msg_count, session_msg_count
|
||||
)
|
||||
|
||||
assert watermark == 14 # 12 + 2, NOT 47
|
||||
# Verify: the gap check would fire on next turn
|
||||
# next-turn check: transcript_msg_count < msg_count - 1 → 14 < 47-1=46 → True
|
||||
assert watermark < session_msg_count - 1
|
||||
|
||||
def test_no_false_positive_when_transcript_current(self):
|
||||
"""Transcript current (watermark=46, DB=47) → gap stays 0.
|
||||
|
||||
When the JSONL actually covers T46 (the most recent assistant turn),
|
||||
uploading watermark=46+2=48 means next turn's gap check sees
|
||||
48 >= 48-1=47 → no gap. Correct.
|
||||
"""
|
||||
use_resume = True
|
||||
transcript_msg_count = 46
|
||||
session_msg_count = 47
|
||||
|
||||
watermark = _compute_jsonl_covered(
|
||||
use_resume, transcript_msg_count, session_msg_count
|
||||
)
|
||||
|
||||
assert watermark == 48 # 46 + 2
|
||||
# Next turn: session has 48 msgs, watermark=48 → 48 >= 48-1=47 → no gap
|
||||
next_turn_session = 48
|
||||
assert watermark >= next_turn_session - 1
|
||||
|
||||
def test_fresh_session_falls_back_to_db_count(self):
|
||||
"""use_resume=False → watermark = len(session.messages) (original behaviour)."""
|
||||
use_resume = False
|
||||
transcript_msg_count = 0
|
||||
session_msg_count = 3
|
||||
|
||||
watermark = _compute_jsonl_covered(
|
||||
use_resume, transcript_msg_count, session_msg_count
|
||||
)
|
||||
|
||||
assert watermark == session_msg_count
|
||||
|
||||
def test_old_format_meta_zero_count_falls_back_to_db(self):
|
||||
"""transcript_msg_count=0 (old-format meta with no count field) → DB fallback."""
|
||||
use_resume = True
|
||||
transcript_msg_count = 0 # old-format meta or not-yet-set
|
||||
session_msg_count = 10
|
||||
|
||||
watermark = _compute_jsonl_covered(
|
||||
use_resume, transcript_msg_count, session_msg_count
|
||||
)
|
||||
|
||||
assert watermark == session_msg_count
|
||||
@@ -12,18 +12,20 @@ from backend.copilot.transcript import (
|
||||
ENTRY_TYPE_MESSAGE,
|
||||
STOP_REASON_END_TURN,
|
||||
STRIPPABLE_TYPES,
|
||||
TRANSCRIPT_STORAGE_PREFIX,
|
||||
TranscriptDownload,
|
||||
TranscriptMode,
|
||||
cleanup_stale_project_dirs,
|
||||
cli_session_path,
|
||||
compact_transcript,
|
||||
delete_transcript,
|
||||
detect_gap,
|
||||
download_transcript,
|
||||
extract_context_messages,
|
||||
projects_base,
|
||||
read_compacted_entries,
|
||||
restore_cli_session,
|
||||
strip_for_upload,
|
||||
strip_progress_entries,
|
||||
strip_stale_thinking_blocks,
|
||||
upload_cli_session,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
@@ -34,18 +36,20 @@ __all__ = [
|
||||
"ENTRY_TYPE_MESSAGE",
|
||||
"STOP_REASON_END_TURN",
|
||||
"STRIPPABLE_TYPES",
|
||||
"TRANSCRIPT_STORAGE_PREFIX",
|
||||
"TranscriptDownload",
|
||||
"TranscriptMode",
|
||||
"cleanup_stale_project_dirs",
|
||||
"cli_session_path",
|
||||
"compact_transcript",
|
||||
"delete_transcript",
|
||||
"detect_gap",
|
||||
"download_transcript",
|
||||
"extract_context_messages",
|
||||
"projects_base",
|
||||
"read_compacted_entries",
|
||||
"restore_cli_session",
|
||||
"strip_for_upload",
|
||||
"strip_progress_entries",
|
||||
"strip_stale_thinking_blocks",
|
||||
"upload_cli_session",
|
||||
"upload_transcript",
|
||||
"validate_transcript",
|
||||
"write_transcript_to_tempfile",
|
||||
|
||||
@@ -297,8 +297,8 @@ class TestStripProgressEntries:
|
||||
|
||||
class TestDeleteTranscript:
|
||||
@pytest.mark.asyncio
|
||||
async def test_deletes_both_jsonl_and_meta(self):
|
||||
"""delete_transcript removes both the .jsonl and .meta.json files."""
|
||||
async def test_deletes_cli_session_and_meta(self):
|
||||
"""delete_transcript removes the CLI session .jsonl and .meta.json."""
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.delete = AsyncMock()
|
||||
|
||||
@@ -309,7 +309,7 @@ class TestDeleteTranscript:
|
||||
):
|
||||
await delete_transcript("user-123", "session-456")
|
||||
|
||||
assert mock_storage.delete.call_count == 3
|
||||
assert mock_storage.delete.call_count == 2
|
||||
paths = [call.args[0] for call in mock_storage.delete.call_args_list]
|
||||
assert any(p.endswith(".jsonl") for p in paths)
|
||||
assert any(p.endswith(".meta.json") for p in paths)
|
||||
@@ -319,7 +319,7 @@ class TestDeleteTranscript:
|
||||
"""If .jsonl delete fails, .meta.json delete is still attempted."""
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.delete = AsyncMock(
|
||||
side_effect=[Exception("jsonl delete failed"), None, None]
|
||||
side_effect=[Exception("jsonl delete failed"), None]
|
||||
)
|
||||
|
||||
with patch(
|
||||
@@ -330,14 +330,14 @@ class TestDeleteTranscript:
|
||||
# Should not raise
|
||||
await delete_transcript("user-123", "session-456")
|
||||
|
||||
assert mock_storage.delete.call_count == 3
|
||||
assert mock_storage.delete.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_meta_delete_failure(self):
|
||||
"""If .meta.json delete fails, no exception propagates."""
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.delete = AsyncMock(
|
||||
side_effect=[None, Exception("meta delete failed"), None]
|
||||
side_effect=[None, Exception("meta delete failed")]
|
||||
)
|
||||
|
||||
with patch(
|
||||
@@ -1015,7 +1015,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1044,7 +1044,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1070,7 +1070,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1096,7 +1096,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1118,7 +1118,7 @@ class TestCleanupStaleProjectDirs:
|
||||
|
||||
nonexistent = str(tmp_path / "does-not-exist" / "projects")
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: nonexistent,
|
||||
)
|
||||
|
||||
@@ -1137,7 +1137,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1165,7 +1165,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1189,7 +1189,7 @@ class TestCleanupStaleProjectDirs:
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
"backend.copilot.transcript.projects_base",
|
||||
lambda: str(projects_dir),
|
||||
)
|
||||
|
||||
@@ -1368,3 +1368,172 @@ class TestStripStaleThinkingBlocks:
|
||||
# Both entries of last turn (msg_last) preserved
|
||||
assert lines[1]["message"]["content"][0]["type"] == "thinking"
|
||||
assert lines[2]["message"]["content"][0]["type"] == "text"
|
||||
|
||||
|
||||
class TestProcessCliRestore:
|
||||
"""``_process_cli_restore`` validates, strips, and writes CLI session to disk."""
|
||||
|
||||
def test_writes_stripped_bytes_not_raw(self, tmp_path):
|
||||
"""Stripped bytes (not raw bytes) must be written to disk for --resume."""
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from backend.copilot.sdk.service import _process_cli_restore
|
||||
from backend.copilot.transcript import TranscriptDownload
|
||||
|
||||
session_id = "12345678-0000-0000-0000-abcdef000001"
|
||||
sdk_cwd = str(tmp_path)
|
||||
projects_base_dir = str(tmp_path)
|
||||
|
||||
# Build raw content with a strippable progress entry + a valid user/assistant pair
|
||||
raw_content = (
|
||||
'{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n'
|
||||
'{"type":"user","uuid":"u1","parentUuid":null,"message":{"role":"user","content":"hi"}}\n'
|
||||
'{"type":"assistant","uuid":"a1","parentUuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello"}]}}\n'
|
||||
)
|
||||
raw_bytes = raw_content.encode("utf-8")
|
||||
restore = TranscriptDownload(content=raw_bytes, message_count=2, mode="sdk")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.projects_base",
|
||||
return_value=projects_base_dir,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.transcript.projects_base",
|
||||
return_value=projects_base_dir,
|
||||
),
|
||||
):
|
||||
stripped_str, ok = _process_cli_restore(
|
||||
restore, sdk_cwd, session_id, "[Test]"
|
||||
)
|
||||
|
||||
assert ok, "Expected successful restore"
|
||||
|
||||
# Find the written session file
|
||||
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||
session_file = Path(projects_base_dir) / encoded_cwd / f"{session_id}.jsonl"
|
||||
assert session_file.exists(), "Session file should have been written"
|
||||
|
||||
written_bytes = session_file.read_bytes()
|
||||
# The written bytes must be the stripped version (no progress entry)
|
||||
assert (
|
||||
b"progress" not in written_bytes
|
||||
), "Raw bytes with progress entry should not have been written"
|
||||
assert (
|
||||
b"hello" in written_bytes
|
||||
), "Stripped content should still contain assistant turn"
|
||||
|
||||
# Written bytes must equal the stripped string re-encoded
|
||||
assert written_bytes == stripped_str.encode(
|
||||
"utf-8"
|
||||
), "Written bytes must equal stripped content"
|
||||
|
||||
def test_invalid_content_returns_false(self):
|
||||
"""Content that fails validation after strip returns (empty, False)."""
|
||||
from backend.copilot.sdk.service import _process_cli_restore
|
||||
from backend.copilot.transcript import TranscriptDownload
|
||||
|
||||
# A single progress-only entry — stripped result will be empty/invalid
|
||||
raw_content = '{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n'
|
||||
restore = TranscriptDownload(
|
||||
content=raw_content.encode("utf-8"), message_count=1, mode="sdk"
|
||||
)
|
||||
|
||||
stripped_str, ok = _process_cli_restore(
|
||||
restore,
|
||||
"/tmp/nonexistent-sdk-cwd",
|
||||
"12345678-0000-0000-0000-000000000099",
|
||||
"[Test]",
|
||||
)
|
||||
|
||||
assert not ok
|
||||
assert stripped_str == ""
|
||||
|
||||
|
||||
class TestReadCliSessionFromDisk:
|
||||
"""``_read_cli_session_from_disk`` reads, strips, and optionally writes back the session."""
|
||||
|
||||
def _build_session_file(self, tmp_path, session_id: str):
|
||||
"""Build the session file path inside tmp_path using the same encoding as cli_session_path."""
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
sdk_cwd = str(tmp_path)
|
||||
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||
session_dir = Path(str(tmp_path)) / encoded_cwd
|
||||
session_dir.mkdir(parents=True, exist_ok=True)
|
||||
return sdk_cwd, session_dir / f"{session_id}.jsonl"
|
||||
|
||||
def test_returns_raw_bytes_for_invalid_utf8(self, tmp_path):
|
||||
"""Non-UTF-8 bytes trigger UnicodeDecodeError — returns raw bytes (upload-raw fallback)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from backend.copilot.sdk.service import _read_cli_session_from_disk
|
||||
|
||||
session_id = "12345678-0000-0000-0000-aabbccdd0001"
|
||||
projects_base_dir = str(tmp_path)
|
||||
sdk_cwd, session_file = self._build_session_file(tmp_path, session_id)
|
||||
|
||||
# Write raw invalid UTF-8 bytes
|
||||
session_file.write_bytes(b"\xff\xfe invalid utf-8\n")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.projects_base",
|
||||
return_value=projects_base_dir,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.transcript.projects_base",
|
||||
return_value=projects_base_dir,
|
||||
),
|
||||
):
|
||||
result = _read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
|
||||
|
||||
# UnicodeDecodeError path returns the raw bytes (upload-raw fallback)
|
||||
assert result == b"\xff\xfe invalid utf-8\n"
|
||||
|
||||
def test_write_back_oserror_still_returns_stripped_bytes(self, tmp_path):
|
||||
"""OSError on write-back returns stripped bytes for GCS upload (not raw)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from backend.copilot.sdk.service import _read_cli_session_from_disk
|
||||
|
||||
session_id = "12345678-0000-0000-0000-aabbccdd0002"
|
||||
projects_base_dir = str(tmp_path)
|
||||
sdk_cwd, session_file = self._build_session_file(tmp_path, session_id)
|
||||
|
||||
# Content with a strippable progress entry so stripped_bytes < raw_bytes
|
||||
raw_content = (
|
||||
'{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n'
|
||||
'{"type":"user","uuid":"u1","parentUuid":null,"message":{"role":"user","content":"hi"}}\n'
|
||||
'{"type":"assistant","uuid":"a1","parentUuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello"}]}}\n'
|
||||
)
|
||||
session_file.write_bytes(raw_content.encode("utf-8"))
|
||||
# Make the file read-only so write_bytes raises OSError on the write-back
|
||||
session_file.chmod(0o444)
|
||||
|
||||
try:
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.sdk.service.projects_base",
|
||||
return_value=projects_base_dir,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.transcript.projects_base",
|
||||
return_value=projects_base_dir,
|
||||
),
|
||||
):
|
||||
result = _read_cli_session_from_disk(sdk_cwd, session_id, "[Test]")
|
||||
finally:
|
||||
session_file.chmod(0o644)
|
||||
|
||||
# Must return stripped bytes (not raw, not None) so GCS gets the clean version
|
||||
assert result is not None
|
||||
assert (
|
||||
b"progress" not in result
|
||||
), "Stripped bytes must not contain progress entry"
|
||||
assert b"hello" in result, "Stripped bytes should contain assistant turn"
|
||||
|
||||
@@ -61,18 +61,23 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
# (CLI version, platform). When that happens, multi-turn still works
# via conversation compression (non-resume path), but we can't test
# the --resume round-trip.
transcript = None
cli_session = None
for _ in range(10):
await asyncio.sleep(0.5)
transcript = await download_transcript(test_user_id, session.session_id)
if transcript:
cli_session = await download_transcript(test_user_id, session.session_id)
# Wait until both the session bytes AND the message_count watermark are
# present — a session with message_count=0 means the .meta.json hasn't
# been uploaded yet, so --resume on the next turn would skip gap-fill.
if cli_session and cli_session.message_count > 0:
break
if not transcript:
if not cli_session:
return pytest.skip(
"CLI did not produce a usable transcript — "
"cannot test --resume round-trip in this environment"
)
logger.info(f"Turn 1 transcript uploaded: {len(transcript.content)} bytes")
logger.info(
f"Turn 1 CLI session uploaded: {len(cli_session.content)} bytes, msg_count={cli_session.message_count}"
)

# Reload session for turn 2
session = await get_chat_session(session.session_id, test_user_id)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"""JSONL transcript management for stateless multi-turn resume.
|
||||
|
||||
The Claude Code CLI persists conversations as JSONL files (one JSON object per
|
||||
line). When the SDK's ``Stop`` hook fires we read this file, strip bloat
|
||||
(progress entries, metadata), and upload the result to bucket storage. On the
|
||||
next turn we download the transcript, write it to a temp file, and pass
|
||||
``--resume`` so the CLI can reconstruct the full conversation.
|
||||
line). When the SDK's ``Stop`` hook fires the caller reads this file, strips
|
||||
bloat (progress entries, metadata), and uploads the result to bucket storage.
|
||||
On the next turn the caller downloads the bytes and writes them to disk before
|
||||
passing ``--resume`` so the CLI can reconstruct the full conversation.
|
||||
|
||||
Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
|
||||
filesystem for self-hosted) — no DB column needed.
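A rough sketch of the per-turn round-trip this docstring describes, stitched together from this module's own exports (argument lists are limited to what is visible elsewhere in this diff; treat it as an outline under those assumptions, not the canonical call sequence):

    from pathlib import Path

    from backend.copilot.transcript import (
        cli_session_path,
        download_transcript,
        strip_for_upload,
        upload_transcript,
        validate_transcript,
    )

    async def after_turn(user_id: str, session_id: str, sdk_cwd: str) -> None:
        # The CLI has just written its native session JSONL for this turn.
        raw = Path(cli_session_path(sdk_cwd, session_id)).read_text()
        await upload_transcript(
            user_id=user_id,
            session_id=session_id,
            content=strip_for_upload(raw),  # drop progress/metadata bloat
            message_count=14,               # example watermark value
            mode="sdk",
        )

    async def before_turn(user_id: str, session_id: str, sdk_cwd: str) -> bool:
        dl = await download_transcript(user_id, session_id)
        if dl is None:
            return False                    # first turn or upload missed
        text = dl.content.decode("utf-8") if isinstance(dl.content, bytes) else dl.content
        if not validate_transcript(text):
            return False                    # caller proceeds without --resume
        Path(cli_session_path(sdk_cwd, session_id)).write_text(text)
        return True                         # caller may pass --resume <session_id>
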
|
||||
@@ -20,6 +20,7 @@ import shutil
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
from backend.util import json
|
||||
@@ -27,6 +28,9 @@ from backend.util.clients import get_openai_client
|
||||
from backend.util.prompt import CompressResult, compress_context
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage, get_workspace_storage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import ChatMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
|
||||
@@ -44,17 +48,17 @@ STRIPPABLE_TYPES = frozenset(
|
||||
)
|
||||
|
||||
|
||||
TranscriptMode = Literal["sdk", "baseline"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptDownload:
|
||||
"""Result of downloading a transcript with its metadata."""
|
||||
|
||||
content: str
|
||||
message_count: int = 0 # session.messages length when uploaded
|
||||
uploaded_at: float = 0.0 # epoch timestamp of upload
|
||||
content: bytes | str
|
||||
message_count: int = 0
|
||||
# "sdk" = Claude CLI native, "baseline" = TranscriptBuilder
|
||||
mode: TranscriptMode = "sdk"
|
||||
|
||||
|
||||
# Workspace storage constants — deterministic path from session_id.
|
||||
TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"
|
||||
# Storage prefix for the CLI's native session JSONL files (for cross-pod --resume).
|
||||
_CLI_SESSION_STORAGE_PREFIX = "cli-sessions"
|
||||
|
||||
@@ -363,7 +367,7 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
||||
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
|
||||
|
||||
|
||||
def _projects_base() -> str:
|
||||
def projects_base() -> str:
|
||||
"""Return the resolved path to the CLI's projects directory."""
|
||||
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
|
||||
return os.path.realpath(os.path.join(config_dir, "projects"))
|
||||
@@ -390,8 +394,8 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
|
||||
|
||||
Returns the number of directories removed.
|
||||
"""
|
||||
projects_base = _projects_base()
|
||||
if not os.path.isdir(projects_base):
|
||||
_pbase = projects_base()
|
||||
if not os.path.isdir(_pbase):
|
||||
return 0
|
||||
|
||||
now = time.time()
|
||||
@@ -399,7 +403,7 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
|
||||
|
||||
# Scoped mode: only clean up the one directory for the current session.
|
||||
if encoded_cwd:
|
||||
target = Path(projects_base) / encoded_cwd
|
||||
target = Path(_pbase) / encoded_cwd
|
||||
if not target.is_dir():
|
||||
return 0
|
||||
# Guard: only sweep copilot-generated dirs.
|
||||
@@ -437,7 +441,7 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int:
|
||||
# Only safe for single-tenant deployments; callers should prefer the
|
||||
# scoped variant by passing encoded_cwd.
|
||||
try:
|
||||
entries = Path(projects_base).iterdir()
|
||||
entries = Path(_pbase).iterdir()
|
||||
except OSError as e:
|
||||
logger.warning("[Transcript] Failed to list projects dir: %s", e)
|
||||
return 0
|
||||
@@ -490,9 +494,9 @@ def read_compacted_entries(transcript_path: str) -> list[dict] | None:
|
||||
if not transcript_path:
|
||||
return None
|
||||
|
||||
projects_base = _projects_base()
|
||||
_pbase = projects_base()
|
||||
real_path = os.path.realpath(transcript_path)
|
||||
if not real_path.startswith(projects_base + os.sep):
|
||||
if not real_path.startswith(_pbase + os.sep):
|
||||
logger.warning(
|
||||
"[Transcript] transcript_path outside projects base: %s", transcript_path
|
||||
)
|
||||
@@ -611,28 +615,6 @@ def validate_transcript(content: str | None) -> bool:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
|
||||
"""Return (workspace_id, file_id, filename) for a session's transcript.
|
||||
|
||||
Path structure: ``chat-transcripts/{user_id}/{session_id}.jsonl``
|
||||
IDs are sanitized to hex+hyphen to prevent path traversal.
|
||||
"""
|
||||
return (
|
||||
TRANSCRIPT_STORAGE_PREFIX,
|
||||
_sanitize_id(user_id),
|
||||
f"{_sanitize_id(session_id)}.jsonl",
|
||||
)
|
||||
|
||||
|
||||
def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
|
||||
"""Return (workspace_id, file_id, filename) for a session's transcript metadata."""
|
||||
return (
|
||||
TRANSCRIPT_STORAGE_PREFIX,
|
||||
_sanitize_id(user_id),
|
||||
f"{_sanitize_id(session_id)}.meta.json",
|
||||
)
|
||||
|
||||
|
||||
def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
|
||||
"""Build a full storage path from (workspace_id, file_id, filename) parts."""
|
||||
wid, fid, fname = parts
|
||||
@@ -642,24 +624,12 @@ def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str:
|
||||
return f"local://{wid}/{fid}/{fname}"
|
||||
|
||||
|
||||
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
|
||||
"""Build the full storage path string that ``retrieve()`` expects."""
|
||||
return _build_path_from_parts(_storage_path_parts(user_id, session_id), backend)
|
||||
|
||||
|
||||
def _build_meta_storage_path(user_id: str, session_id: str, backend: object) -> str:
|
||||
"""Build the full storage path for the companion .meta.json file."""
|
||||
return _build_path_from_parts(
|
||||
_meta_storage_path_parts(user_id, session_id), backend
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI native session file — cross-pod --resume support
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _cli_session_path(sdk_cwd: str, session_id: str) -> str:
|
||||
def cli_session_path(sdk_cwd: str, session_id: str) -> str:
|
||||
"""Expected path of the CLI's native session JSONL file.
|
||||
|
||||
The CLI resolves the working directory via ``os.path.realpath``, then
|
||||
@@ -675,7 +645,7 @@ def _cli_session_path(sdk_cwd: str, session_id: str) -> str:
|
||||
"""
|
||||
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||
safe_id = _sanitize_id(session_id)
|
||||
return os.path.join(_projects_base(), encoded_cwd, f"{safe_id}.jsonl")
|
||||
return os.path.join(projects_base(), encoded_cwd, f"{safe_id}.jsonl")
|
||||
|
||||
|
||||
def _cli_session_storage_path_parts(
|
||||
@@ -689,235 +659,82 @@ def _cli_session_storage_path_parts(
|
||||
)
|
||||
|
||||
|
||||
async def upload_cli_session(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
sdk_cwd: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> None:
|
||||
"""Upload the CLI's native session JSONL file to remote storage.
|
||||
|
||||
Called after each turn so the next turn can restore the file on any pod
|
||||
(eliminating the pod-affinity requirement for --resume).
|
||||
|
||||
The CLI only writes the session file after the turn completes, so this
|
||||
must run in the finally block, AFTER the SDK stream has finished.
|
||||
"""
|
||||
session_file = _cli_session_path(sdk_cwd, session_id)
|
||||
real_path = os.path.realpath(session_file)
|
||||
projects_base = _projects_base()
|
||||
|
||||
if not real_path.startswith(projects_base + os.sep):
|
||||
logger.warning(
|
||||
"%s CLI session file outside projects base, skipping upload: %s",
|
||||
log_prefix,
|
||||
os.path.basename(real_path),
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
raw_bytes = Path(real_path).read_bytes()
|
||||
except FileNotFoundError:
|
||||
logger.debug(
|
||||
"%s CLI session file not found, skipping upload: %s",
|
||||
log_prefix,
|
||||
session_file,
|
||||
)
|
||||
return
|
||||
except OSError as e:
|
||||
logger.warning("%s Failed to read CLI session file: %s", log_prefix, e)
|
||||
return
|
||||
|
||||
# Strip stale thinking blocks and metadata entries (progress, file-history-snapshot,
|
||||
# queue-operation) from the CLI session before writing it back locally and uploading
|
||||
# to GCS. Thinking blocks from non-last assistant turns are not needed for --resume
|
||||
# but can be massive (tens of thousands of tokens each), causing the CLI to auto-compact
|
||||
# its session when the context window fills up. Stripping keeps the session well below
|
||||
# the ~200K-token compaction threshold and prevents silent context loss.
|
||||
try:
|
||||
raw_text = raw_bytes.decode("utf-8")
|
||||
stripped_text = strip_for_upload(raw_text)
|
||||
stripped_bytes = stripped_text.encode("utf-8")
|
||||
if len(stripped_bytes) < len(raw_bytes):
|
||||
# Write the stripped version back locally so same-pod turns also benefit.
|
||||
Path(real_path).write_bytes(stripped_bytes)
|
||||
logger.info(
|
||||
"%s Stripped CLI session file: %dB → %dB",
|
||||
log_prefix,
|
||||
len(raw_bytes),
|
||||
len(stripped_bytes),
|
||||
)
|
||||
content = stripped_bytes
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"%s Failed to strip CLI session file, uploading raw: %s", log_prefix, e
|
||||
)
|
||||
content = raw_bytes
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id)
|
||||
try:
|
||||
await storage.store(
|
||||
workspace_id=wid, file_id=fid, filename=fname, content=content
|
||||
)
|
||||
logger.info(
|
||||
"%s Uploaded CLI session file (%dB) for cross-pod --resume",
|
||||
log_prefix,
|
||||
len(content),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("%s Failed to upload CLI session file: %s", log_prefix, e)
|
||||
|
||||
|
||||
async def restore_cli_session(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
sdk_cwd: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> bool:
|
||||
"""Download and restore the CLI's native session file for --resume.
|
||||
|
||||
Returns True if the file was successfully restored and --resume can be
|
||||
used with the session UUID. Returns False if not available (first turn
|
||||
or upload failed), in which case the caller should not set --resume.
|
||||
"""
|
||||
session_file = _cli_session_path(sdk_cwd, session_id)
|
||||
real_path = os.path.realpath(session_file)
|
||||
projects_base = _projects_base()
|
||||
|
||||
if not real_path.startswith(projects_base + os.sep):
|
||||
logger.warning(
|
||||
"%s CLI session restore path outside projects base: %s",
|
||||
log_prefix,
|
||||
os.path.basename(session_file),
|
||||
)
|
||||
return False
|
||||
|
||||
# If the session file already exists locally (same-pod reuse), use it directly.
|
||||
# Downloading from storage could overwrite a newer local version when a previous
|
||||
# turn's upload failed: stored content is stale while the local file already
|
||||
# contains extended history from that turn.
|
||||
if Path(real_path).exists():
|
||||
logger.debug(
|
||||
"%s CLI session file already exists locally — using it for --resume",
|
||||
log_prefix,
|
||||
)
|
||||
return True
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_path_from_parts(
|
||||
_cli_session_storage_path_parts(user_id, session_id), storage
|
||||
def _cli_session_meta_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
|
||||
"""Return (workspace_id, file_id, filename) for the CLI session meta file."""
|
||||
return (
|
||||
_CLI_SESSION_STORAGE_PREFIX,
|
||||
_sanitize_id(user_id),
|
||||
f"{_sanitize_id(session_id)}.meta.json",
|
||||
)
|
||||
|
||||
try:
|
||||
content = await storage.retrieve(path)
|
||||
except FileNotFoundError:
|
||||
logger.debug("%s No CLI session in storage (first turn or missing)", log_prefix)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning("%s Failed to download CLI session: %s", log_prefix, e)
|
||||
return False
|
||||
|
||||
try:
|
||||
os.makedirs(os.path.dirname(real_path), exist_ok=True)
|
||||
Path(real_path).write_bytes(content)
|
||||
logger.info(
|
||||
"%s Restored CLI session file (%dB) for --resume",
|
||||
log_prefix,
|
||||
len(content),
|
||||
)
|
||||
return True
|
||||
except OSError as e:
|
||||
logger.warning("%s Failed to write CLI session file: %s", log_prefix, e)
|
||||
return False
|
||||
|
||||
|
||||
async def upload_transcript(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
content: str,
|
||||
content: bytes,
|
||||
message_count: int = 0,
|
||||
mode: TranscriptMode = "sdk",
|
||||
log_prefix: str = "[Transcript]",
|
||||
skip_strip: bool = False,
|
||||
) -> None:
|
||||
"""Strip progress entries and stale thinking blocks, then upload transcript.
|
||||
"""Upload CLI session content to GCS with companion meta.json.
|
||||
|
||||
The transcript represents the FULL active context (atomic).
|
||||
Each upload REPLACES the previous transcript entirely.
|
||||
Pure GCS operation — no disk I/O. The caller is responsible for reading
|
||||
the session file from disk before calling this function.
|
||||
|
||||
The executor holds a cluster lock per session, so concurrent uploads for
|
||||
the same session cannot happen.
|
||||
Also uploads a companion .meta.json with the message_count watermark so
|
||||
download_transcript can return it without a separate fetch.
|
||||
|
||||
Args:
|
||||
content: Complete JSONL transcript (from TranscriptBuilder).
|
||||
message_count: ``len(session.messages)`` at upload time.
|
||||
skip_strip: When ``True``, skip the strip + re-validate pass.
|
||||
Safe for builder-generated content (baseline path) which
|
||||
never emits progress entries or stale thinking blocks.
|
||||
Called after each turn so the next turn can restore the file on any pod
|
||||
(eliminating the pod-affinity requirement for --resume).
|
||||
"""
|
||||
if skip_strip:
|
||||
# Caller guarantees the content is already clean and valid.
|
||||
stripped = content
|
||||
else:
|
||||
# Strip metadata entries and stale thinking blocks in a single parse.
|
||||
# SDK-built transcripts may have progress entries; strip for safety.
|
||||
stripped = strip_for_upload(content)
|
||||
if not skip_strip and not validate_transcript(stripped):
|
||||
# Log entry types for debugging — helps identify why validation failed
|
||||
entry_types = [
|
||||
json.loads(line, fallback={"type": "INVALID_JSON"}).get("type", "?")
|
||||
for line in stripped.strip().split("\n")
|
||||
]
|
||||
logger.warning(
|
||||
"%s Skipping upload — stripped content not valid "
|
||||
"(types=%s, stripped_len=%d, raw_len=%d)",
|
||||
log_prefix,
|
||||
entry_types,
|
||||
len(stripped),
|
||||
len(content),
|
||||
)
|
||||
logger.debug("%s Raw content preview: %s", log_prefix, content[:500])
|
||||
logger.debug("%s Stripped content: %s", log_prefix, stripped[:500])
|
||||
return
|
||||
|
||||
storage = await get_workspace_storage()
|
||||
wid, fid, fname = _storage_path_parts(user_id, session_id)
|
||||
encoded = stripped.encode("utf-8")
|
||||
meta = {"message_count": message_count, "uploaded_at": time.time()}
|
||||
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
|
||||
wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id)
|
||||
mwid, mfid, mfname = _cli_session_meta_path_parts(user_id, session_id)
|
||||
meta = {"message_count": message_count, "mode": mode, "uploaded_at": time.time()}
|
||||
meta_encoded = json.dumps(meta).encode("utf-8")
|
||||
|
||||
# Transcript + metadata are independent objects at different keys, so
|
||||
# write them concurrently. ``return_exceptions`` keeps a metadata
|
||||
# failure from sinking the transcript write.
|
||||
transcript_result, metadata_result = await asyncio.gather(
|
||||
storage.store(
|
||||
workspace_id=wid,
|
||||
file_id=fid,
|
||||
filename=fname,
|
||||
content=encoded,
|
||||
),
|
||||
storage.store(
|
||||
workspace_id=mwid,
|
||||
file_id=mfid,
|
||||
filename=mfname,
|
||||
content=meta_encoded,
|
||||
),
|
||||
return_exceptions=True,
|
||||
)
|
||||
if isinstance(transcript_result, BaseException):
|
||||
raise transcript_result
|
||||
if isinstance(metadata_result, BaseException):
|
||||
# Metadata is best-effort — the gap-fill logic in
|
||||
# _build_query_message tolerates a missing metadata file.
|
||||
logger.warning("%s Failed to write metadata: %s", log_prefix, metadata_result)
|
||||
# Write JSONL first, meta second — sequential so a crash between the two
|
||||
# leaves an orphaned JSONL (no meta) rather than an orphaned meta (wrong
|
||||
# watermark / mode paired with stale or absent content).
|
||||
# On any failure we roll back the other file so the pair is always absent
|
||||
# together; download_transcript returns None when either file is missing.
|
||||
try:
|
||||
await storage.store(
|
||||
workspace_id=wid, file_id=fid, filename=fname, content=content
|
||||
)
|
||||
except Exception as session_err:
|
||||
logger.warning(
|
||||
"%s Failed to upload CLI session file: %s", log_prefix, session_err
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
await storage.store(
|
||||
workspace_id=mwid, file_id=mfid, filename=mfname, content=meta_encoded
|
||||
)
|
||||
except Exception as meta_err:
|
||||
logger.warning("%s Failed to upload CLI session meta: %s", log_prefix, meta_err)
|
||||
# Roll back the JSONL so neither file exists — avoids orphaned JSONL being
|
||||
# used with wrong mode/watermark defaults on the next restore.
|
||||
try:
|
||||
session_path = _build_path_from_parts(
|
||||
_cli_session_storage_path_parts(user_id, session_id), storage
|
||||
)
|
||||
await storage.delete(session_path)
|
||||
except Exception as rollback_err:
|
||||
logger.debug(
|
||||
"%s Session rollback failed (harmless — download will return None): %s",
|
||||
log_prefix,
|
||||
rollback_err,
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
|
||||
"%s Uploaded CLI session (%dB, msg_count=%d, mode=%s)",
|
||||
log_prefix,
|
||||
len(encoded),
|
||||
len(content),
|
||||
message_count,
|
||||
mode,
|
||||
)
|
||||
|
||||
|
||||
@@ -926,83 +743,173 @@ async def download_transcript(
|
||||
session_id: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> TranscriptDownload | None:
|
||||
"""Download transcript and metadata from bucket storage.
|
||||
"""Download CLI session from GCS. Returns content + message_count + mode, or None if not found.
|
||||
|
||||
Returns a ``TranscriptDownload`` with the JSONL content and the
|
||||
``message_count`` watermark from the upload, or ``None`` if not found.
|
||||
Pure GCS operation — no disk I/O. The caller is responsible for writing
|
||||
content to disk if --resume is needed.
|
||||
|
||||
The content and metadata fetches run concurrently since they are
|
||||
independent objects in the bucket.
|
||||
Returns a TranscriptDownload with the raw content, message_count watermark,
|
||||
and mode on success, or None if not available (first turn or upload failed).
|
||||
"""
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
meta_path = _build_meta_storage_path(user_id, session_id, storage)
|
||||
path = _build_path_from_parts(
|
||||
_cli_session_storage_path_parts(user_id, session_id), storage
|
||||
)
|
||||
meta_path = _build_path_from_parts(
|
||||
_cli_session_meta_path_parts(user_id, session_id), storage
|
||||
)
|
||||
|
||||
content_task = asyncio.create_task(storage.retrieve(path))
|
||||
meta_task = asyncio.create_task(storage.retrieve(meta_path))
|
||||
content_result, meta_result = await asyncio.gather(
|
||||
content_task, meta_task, return_exceptions=True
|
||||
storage.retrieve(path),
|
||||
storage.retrieve(meta_path),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
if isinstance(content_result, FileNotFoundError):
|
||||
logger.debug("%s No transcript in storage", log_prefix)
|
||||
logger.debug("%s No CLI session in storage (first turn or missing)", log_prefix)
|
||||
return None
|
||||
if isinstance(content_result, BaseException):
|
||||
logger.warning(
|
||||
"%s Failed to download transcript: %s", log_prefix, content_result
|
||||
"%s Failed to download CLI session: %s", log_prefix, content_result
|
||||
)
|
||||
return None
|
||||
|
||||
content = content_result.decode("utf-8")
|
||||
content: bytes = content_result
|
||||
|
||||
# Metadata is best-effort — old transcripts won't have it.
|
||||
# Parse message_count and mode from companion meta — best-effort, defaults.
|
||||
message_count = 0
|
||||
uploaded_at = 0.0
|
||||
mode: TranscriptMode = "sdk"
|
||||
if isinstance(meta_result, FileNotFoundError):
|
||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||
pass # No meta — old upload; default to "sdk"
|
||||
elif isinstance(meta_result, BaseException):
|
||||
logger.debug(
|
||||
"%s Failed to load transcript metadata: %s", log_prefix, meta_result
|
||||
)
|
||||
logger.debug("%s Failed to load CLI session meta: %s", log_prefix, meta_result)
|
||||
else:
|
||||
meta = json.loads(meta_result.decode("utf-8"), fallback={})
|
||||
message_count = meta.get("message_count", 0)
|
||||
uploaded_at = meta.get("uploaded_at", 0.0)
|
||||
try:
|
||||
meta_str = meta_result.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
logger.debug("%s CLI session meta is not valid UTF-8, ignoring", log_prefix)
|
||||
meta_str = None
|
||||
if meta_str is not None:
|
||||
meta = json.loads(meta_str, fallback={})
|
||||
if isinstance(meta, dict):
|
||||
raw_count = meta.get("message_count", 0)
|
||||
message_count = (
|
||||
raw_count if isinstance(raw_count, int) and raw_count >= 0 else 0
|
||||
)
|
||||
raw_mode = meta.get("mode", "sdk")
|
||||
mode = raw_mode if raw_mode in ("sdk", "baseline") else "sdk"
|
||||
|
||||
logger.info(
|
||||
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
|
||||
)
|
||||
return TranscriptDownload(
|
||||
content=content,
|
||||
message_count=message_count,
|
||||
uploaded_at=uploaded_at,
|
||||
"%s Downloaded CLI session (%dB, msg_count=%d, mode=%s)",
|
||||
log_prefix,
|
||||
len(content),
|
||||
message_count,
|
||||
mode,
|
||||
)
|
||||
return TranscriptDownload(content=content, message_count=message_count, mode=mode)
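# Illustrative round trip (values hypothetical): upload_transcript writes a meta
# payload shaped like
#   {"message_count": 12, "mode": "baseline", "uploaded_at": 1712345678.9}
# and a later download then yields
#   TranscriptDownload(content=b"<jsonl bytes>", message_count=12, mode="baseline")
# A missing, malformed, or non-UTF-8 meta file collapses to message_count=0 and
# mode="sdk", i.e. an unknown watermark for downstream gap detection.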
|
||||
|
||||
|
||||
def detect_gap(
|
||||
download: TranscriptDownload,
|
||||
session_messages: list[ChatMessage],
|
||||
) -> list[ChatMessage]:
|
||||
"""Return chat-db messages after the transcript watermark (excluding current user turn).
|
||||
|
||||
Returns [] if transcript is current, watermark is zero, or the watermark
|
||||
position doesn't end on an assistant turn (misaligned watermark).
|
||||
"""
|
||||
if download.message_count == 0:
|
||||
return []
|
||||
wm = download.message_count
|
||||
total = len(session_messages)
|
||||
if wm >= total - 1:
|
||||
return []
|
||||
# Sanity: position wm-1 should be an assistant turn; misaligned watermark
|
||||
# means the DB messages shifted (e.g. deletion) — skip gap to avoid wrong context.
|
||||
# In normal operation ``message_count`` is always written after a complete
|
||||
# user→assistant exchange (never mid-turn), so the last covered position is
|
||||
# always assistant. This guard fires only on data corruption or message deletion.
|
||||
if session_messages[wm - 1].role != "assistant":
|
||||
return []
|
||||
return list(session_messages[wm : total - 1])
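# Worked example (hypothetical session, 0-indexed):
#   session_messages = [user0, assistant1, user2, assistant3, user4(current)]
#   download.message_count = 2   -> transcript covers [user0, assistant1]
#   wm=2 < total-1=4 and session_messages[1].role == "assistant", so
#   gap = session_messages[2:4] = [user2, assistant3]  (current turn user4 excluded)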
|
||||
|
||||
|
||||
def extract_context_messages(
|
||||
download: TranscriptDownload | None,
|
||||
session_messages: "list[ChatMessage]",
|
||||
) -> "list[ChatMessage]":
|
||||
"""Return context messages for the current turn: transcript content + gap.
|
||||
|
||||
This is the shared context primitive used by both the SDK path
|
||||
(``use_resume=False`` → ``<conversation_history>`` injection) and the
|
||||
baseline path (OpenAI messages array).
|
||||
|
||||
How it works:
|
||||
|
||||
- When a transcript exists, ``TranscriptBuilder.load_previous`` preserves
|
||||
``isCompactSummary=True`` compaction entries, so the returned messages
|
||||
mirror the compacted context the CLI would see via ``--resume``.
|
||||
- The gap (DB messages after the transcript watermark) is always small in
|
||||
normal operation; it only grows during mode switches or when an upload
|
||||
was missed.
|
||||
- Falls back to full DB messages when no transcript exists (first turn,
|
||||
upload failure, or GCS unavailable).
|
||||
- Returns *prior* messages only (excluding the current user turn at
|
||||
``session_messages[-1]``). Callers that need the current turn append
|
||||
``session_messages[-1]`` themselves.
|
||||
- **Tool calls from transcript entries are flattened to text**: assistant
|
||||
messages derived from the JSONL use ``_flatten_assistant_content``, which
|
||||
serialises ``tool_use`` blocks as human-readable text rather than
|
||||
structured ``tool_calls``. Gap messages (from DB) preserve their
|
||||
original ``tool_calls`` field. This is the same trade-off as the old
|
||||
``_compress_session_messages(session.messages)`` approach — no regression.
|
||||
|
||||
Args:
|
||||
download: The ``TranscriptDownload`` from GCS, or ``None`` when no
|
||||
transcript is available. ``content`` may be either ``bytes`` or
|
||||
``str`` (the baseline path decodes + strips before returning).
|
||||
session_messages: All messages in the session, with the current user
|
||||
turn as the last element.
|
||||
|
||||
Returns:
|
||||
A list of ``ChatMessage`` objects covering the prior conversation
|
||||
context, suitable for injection as conversation history.
|
||||
"""
|
||||
from .model import ChatMessage as _ChatMessage # runtime import
|
||||
|
||||
prior = session_messages[:-1]
|
||||
|
||||
if download is None:
|
||||
return prior
|
||||
|
||||
raw_content = download.content
|
||||
if not raw_content:
|
||||
return prior
|
||||
|
||||
# Handle both bytes (raw GCS download) and str (pre-decoded baseline path).
|
||||
if isinstance(raw_content, bytes):
|
||||
try:
|
||||
content_str: str = raw_content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return prior
|
||||
else:
|
||||
content_str = raw_content
|
||||
|
||||
raw = _transcript_to_messages(content_str)
|
||||
if not raw:
|
||||
return prior
|
||||
|
||||
transcript_msgs = [
|
||||
_ChatMessage(role=m["role"], content=m.get("content") or "") for m in raw
|
||||
]
|
||||
gap = detect_gap(download, session_messages)
|
||||
return transcript_msgs + gap
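# Illustrative caller sketch (names hypothetical): both the SDK and baseline paths
# re-append the current user turn themselves.
#
#   dl = await download_transcript(user_id, session_id)
#   history = extract_context_messages(dl, session.messages)
#   prompt_messages = history + [session.messages[-1]]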
|
||||
|
||||
|
||||
async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
"""Delete transcript and its metadata from bucket storage.
|
||||
|
||||
Removes both the ``.jsonl`` transcript and the companion ``.meta.json``
|
||||
so stale ``message_count`` watermarks cannot corrupt gap-fill logic.
|
||||
"""
|
||||
"""Delete CLI session JSONL and its companion .meta.json from bucket storage."""
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
|
||||
try:
|
||||
await storage.delete(path)
|
||||
logger.info("[Transcript] Deleted transcript for session %s", session_id)
|
||||
except Exception as e:
|
||||
logger.warning("[Transcript] Failed to delete transcript: %s", e)
|
||||
|
||||
# Also delete the companion .meta.json to avoid orphaned metadata.
|
||||
try:
|
||||
meta_path = _build_meta_storage_path(user_id, session_id, storage)
|
||||
await storage.delete(meta_path)
|
||||
logger.info("[Transcript] Deleted metadata for session %s", session_id)
|
||||
except Exception as e:
|
||||
logger.warning("[Transcript] Failed to delete metadata: %s", e)
|
||||
|
||||
# Also delete the CLI native session file to prevent storage growth.
|
||||
try:
|
||||
cli_path = _build_path_from_parts(
|
||||
_cli_session_storage_path_parts(user_id, session_id), storage
|
||||
@@ -1012,6 +919,15 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
except Exception as e:
|
||||
logger.warning("[Transcript] Failed to delete CLI session: %s", e)
|
||||
|
||||
try:
|
||||
cli_meta_path = _build_path_from_parts(
|
||||
_cli_session_meta_path_parts(user_id, session_id), storage
|
||||
)
|
||||
await storage.delete(cli_meta_path)
|
||||
logger.info("[Transcript] Deleted CLI session meta for session %s", session_id)
|
||||
except Exception as e:
|
||||
logger.warning("[Transcript] Failed to delete CLI session meta: %s", e)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transcript compaction — LLM summarization for prompt-too-long recovery
|
||||
|
||||
@@ -16,11 +16,11 @@ from .transcript import (
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
_meta_storage_path_parts,
|
||||
_rechain_tail,
|
||||
_sanitize_id,
|
||||
_storage_path_parts,
|
||||
_transcript_to_messages,
|
||||
detect_gap,
|
||||
extract_context_messages,
|
||||
strip_for_upload,
|
||||
validate_transcript,
|
||||
)
|
||||
@@ -64,24 +64,6 @@ class TestSanitizeId:
|
||||
assert _sanitize_id("!@#$%^&*()") == "unknown"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _storage_path_parts / _meta_storage_path_parts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStoragePathParts:
|
||||
def test_returns_triple(self):
|
||||
prefix, uid, fname = _storage_path_parts("user-1", "sess-2")
|
||||
assert prefix == "chat-transcripts"
|
||||
assert "e" in uid # hex chars from "user-1" sanitized
|
||||
assert fname.endswith(".jsonl")
|
||||
|
||||
def test_meta_returns_meta_json(self):
|
||||
prefix, _, fname = _meta_storage_path_parts("user-1", "sess-2")
|
||||
assert prefix == "chat-transcripts"
|
||||
assert fname.endswith(".meta.json")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_path_from_parts
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -103,24 +85,6 @@ class TestBuildPathFromParts:
|
||||
assert path == "local://wid/fid/file.jsonl"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TranscriptDownload dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTranscriptDownload:
|
||||
def test_defaults(self):
|
||||
td = TranscriptDownload(content="hello")
|
||||
assert td.content == "hello"
|
||||
assert td.message_count == 0
|
||||
assert td.uploaded_at == 0.0
|
||||
|
||||
def test_custom_values(self):
|
||||
td = TranscriptDownload(content="data", message_count=5, uploaded_at=123.45)
|
||||
assert td.message_count == 5
|
||||
assert td.uploaded_at == 123.45
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_assistant_content
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -733,190 +697,188 @@ class TestValidateTranscript:
|
||||
|
||||
class TestCliSessionPath:
|
||||
def test_encodes_slashes_to_dashes(self):
|
||||
from .transcript import _cli_session_path, _projects_base
|
||||
from .transcript import cli_session_path, projects_base
|
||||
|
||||
sdk_cwd = "/tmp/copilot-abc"
|
||||
result = _cli_session_path(sdk_cwd, "12345678-1234-1234-1234-123456789abc")
|
||||
base = _projects_base()
|
||||
result = cli_session_path(sdk_cwd, "12345678-1234-1234-1234-123456789abc")
|
||||
base = projects_base()
|
||||
assert result.startswith(base)
|
||||
# Encoded cwd replaces '/' with '-'
|
||||
assert "-tmp-copilot-abc" in result
|
||||
assert result.endswith(".jsonl")
|
||||
|
||||
def test_sanitizes_session_id(self):
|
||||
from .transcript import _cli_session_path
|
||||
from .transcript import cli_session_path
|
||||
|
||||
result = _cli_session_path("/tmp/cwd", "../../etc/passwd")
|
||||
result = cli_session_path("/tmp/cwd", "../../etc/passwd")
|
||||
# _sanitize_id strips non-hex/hyphen chars; path traversal impossible
|
||||
assert ".." not in result
|
||||
assert "passwd" not in result
|
||||
|
||||
|
||||
class TestUploadCliSession:
|
||||
def test_skips_upload_when_path_outside_projects_base(self, tmp_path):
|
||||
"""Files outside the CLI projects base are rejected without upload."""
|
||||
def test_uploads_content_bytes_successfully(self):
|
||||
"""Happy path: content bytes are stored as jsonl + meta.json."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import upload_cli_session
|
||||
from .transcript import upload_transcript
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
content = b'{"type":"assistant"}\n'
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
# Return a path that is genuinely outside tmp_path so that
|
||||
# realpath(session_file).startswith(projects_base + "/") is False
|
||||
# and the boundary guard actually fires.
|
||||
patch(
|
||||
"backend.copilot.transcript._cli_session_path",
|
||||
return_value="/outside/escaped/session.jsonl",
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
),
|
||||
with patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
asyncio.run(
|
||||
upload_cli_session(
|
||||
upload_transcript(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000000",
|
||||
sdk_cwd=str(tmp_path),
|
||||
session_id="12345678-0000-0000-0000-000000000001",
|
||||
content=content,
|
||||
)
|
||||
)
|
||||
|
||||
# storage.store must NOT be called — boundary guard should reject the path
|
||||
mock_storage.store.assert_not_called()
|
||||
# Two calls expected: session JSONL + companion .meta.json
|
||||
assert mock_storage.store.call_count == 2
|
||||
|
||||
def test_skips_upload_when_file_not_found(self, tmp_path):
|
||||
"""Missing CLI session file logs debug and skips upload silently."""
|
||||
def test_uploads_companion_meta_json_with_message_count(self):
|
||||
"""upload_transcript stores a companion .meta.json with message_count."""
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import upload_transcript
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
content = b'{"type":"assistant"}\n'
|
||||
|
||||
with patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
asyncio.run(
|
||||
upload_transcript(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000010",
|
||||
content=content,
|
||||
message_count=5,
|
||||
)
|
||||
)
|
||||
|
||||
assert mock_storage.store.call_count == 2
|
||||
# Find the meta.json store call
|
||||
meta_call = next(
|
||||
c
|
||||
for c in mock_storage.store.call_args_list
|
||||
if c.kwargs.get("filename", "").endswith(".meta.json")
|
||||
)
|
||||
meta_content = json.loads(meta_call.kwargs["content"])
|
||||
assert meta_content["message_count"] == 5
|
||||
|
||||
def test_skips_upload_on_storage_failure(self):
|
||||
"""Storage exception on jsonl write is logged and does not propagate.
|
||||
|
||||
With sequential writes, JSONL failure returns early — meta store is
|
||||
never called, so no rollback is needed.
|
||||
"""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import upload_cli_session
|
||||
from .transcript import upload_transcript
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
projects_base = str(tmp_path)
|
||||
mock_storage.store.side_effect = RuntimeError("gcs unavailable")
|
||||
content = b'{"type":"assistant"}\n'
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
return_value=projects_base,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
),
|
||||
with patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
# session file doesn't exist — should not raise
|
||||
# Should not raise — failures are logged as warnings
|
||||
asyncio.run(
|
||||
upload_cli_session(
|
||||
upload_transcript(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000000",
|
||||
sdk_cwd=str(tmp_path),
|
||||
)
|
||||
)
|
||||
|
||||
mock_storage.store.assert_not_called()
|
||||
|
||||
def test_uploads_file_successfully(self, tmp_path):
|
||||
"""Happy path: session file exists within projects base → upload called."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import _sanitize_id, upload_cli_session
|
||||
|
||||
projects_base = str(tmp_path)
|
||||
session_id = "12345678-0000-0000-0000-000000000001"
|
||||
sdk_cwd = str(tmp_path)
|
||||
|
||||
# Build the path the same way _cli_session_path does, but using our tmp_path
|
||||
# as projects_base so the boundary check passes.
|
||||
# Must use the same encoding: re.sub non-alphanumeric → "-" on realpath.
|
||||
import os
|
||||
import re
|
||||
|
||||
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||
session_dir = tmp_path / encoded_cwd
|
||||
session_dir.mkdir(parents=True, exist_ok=True)
|
||||
session_file = session_dir / f"{_sanitize_id(session_id)}.jsonl"
|
||||
session_file.write_bytes(b'{"type":"assistant"}\n')
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
return_value=projects_base,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
),
|
||||
):
|
||||
asyncio.run(
|
||||
upload_cli_session(
|
||||
user_id="user-1",
|
||||
session_id=session_id,
|
||||
sdk_cwd=sdk_cwd,
|
||||
session_id="12345678-0000-0000-0000-000000000002",
|
||||
content=content,
|
||||
)
|
||||
)
|
||||
|
||||
# Only one store call attempted (the JSONL); meta never reached
|
||||
mock_storage.store.assert_called_once()
|
||||
mock_storage.delete.assert_not_called()
|
||||
|
||||
def test_skips_upload_on_oserror(self, tmp_path):
|
||||
"""OSError reading session file is logged as warning; upload is skipped."""
|
||||
def test_rolls_back_session_when_meta_upload_fails(self):
|
||||
"""When meta upload fails after JSONL succeeds, JSONL is rolled back.
|
||||
|
||||
Guarantees the pair is either both present or both absent — avoids an
|
||||
orphaned JSONL being used with wrong mode/watermark defaults.
|
||||
"""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import _sanitize_id, upload_cli_session
|
||||
|
||||
projects_base = str(tmp_path)
|
||||
sdk_cwd = str(tmp_path)
|
||||
session_id = "12345678-0000-0000-0000-000000000002"
|
||||
|
||||
# Build file at a path inside projects_base so boundary check passes.
|
||||
import os
|
||||
import re
|
||||
|
||||
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd))
|
||||
session_dir = tmp_path / encoded_cwd
|
||||
session_dir.mkdir(parents=True, exist_ok=True)
|
||||
session_file = session_dir / f"{_sanitize_id(session_id)}.jsonl"
|
||||
session_file.write_bytes(b'{"type":"assistant"}\n')
|
||||
# Remove read permission to trigger OSError
|
||||
session_file.chmod(0o000)
|
||||
from .transcript import upload_transcript
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
# First store (JSONL) succeeds; second store (meta) fails
|
||||
mock_storage.store.side_effect = [None, RuntimeError("meta write failed")]
|
||||
content = b'{"type":"assistant"}\n'
|
||||
|
||||
try:
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
return_value=projects_base,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
),
|
||||
):
|
||||
asyncio.run(
|
||||
upload_cli_session(
|
||||
user_id="user-1",
|
||||
session_id=session_id,
|
||||
sdk_cwd=sdk_cwd,
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
asyncio.run(
|
||||
upload_transcript(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000099",
|
||||
content=content,
|
||||
)
|
||||
finally:
|
||||
session_file.chmod(0o644) # restore so tmp_path cleanup works
|
||||
)
|
||||
|
||||
mock_storage.store.assert_not_called()
|
||||
# Both store calls were attempted (JSONL then meta)
|
||||
assert mock_storage.store.call_count == 2
|
||||
# JSONL should be rolled back via delete
|
||||
mock_storage.delete.assert_called_once()
|
||||
|
||||
def test_baseline_mode_stored_in_meta(self):
|
||||
"""upload_transcript with mode='baseline' stores mode in companion meta.json."""
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import upload_transcript
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
content = b'{"type":"assistant"}\n'
|
||||
|
||||
with patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
asyncio.run(
|
||||
upload_transcript(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000098",
|
||||
content=content,
|
||||
message_count=4,
|
||||
mode="baseline",
|
||||
)
|
||||
)
|
||||
|
||||
meta_call = next(
|
||||
c
|
||||
for c in mock_storage.store.call_args_list
|
||||
if c.kwargs.get("filename", "").endswith(".meta.json")
|
||||
)
|
||||
meta_content = json.loads(meta_call.kwargs["content"])
|
||||
assert meta_content["mode"] == "baseline"
|
||||
assert meta_content["message_count"] == 4
|
||||
|
||||
def test_strips_session_before_upload_and_writes_back(self, tmp_path):
|
||||
"""Strippable entries (progress, thinking blocks) are removed before upload.
|
||||
@@ -1116,15 +1078,18 @@ class TestUploadCliSession:
|
||||
|
||||
|
||||
class TestRestoreCliSession:
|
||||
def test_returns_false_when_file_not_found_in_storage(self):
|
||||
"""Returns False (graceful degradation) when the session is missing."""
|
||||
def test_returns_none_when_file_not_found_in_storage(self):
|
||||
"""Returns None (graceful degradation) when the session is missing."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import restore_cli_session
|
||||
from .transcript import download_transcript
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.retrieve.side_effect = FileNotFoundError("not found")
|
||||
mock_storage.retrieve.side_effect = [
|
||||
FileNotFoundError("no session"),
|
||||
FileNotFoundError("no meta"),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
@@ -1132,144 +1097,26 @@ class TestRestoreCliSession:
|
||||
return_value=mock_storage,
|
||||
):
|
||||
result = asyncio.run(
|
||||
restore_cli_session(
|
||||
download_transcript(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000000",
|
||||
sdk_cwd="/tmp/copilot-test",
|
||||
)
|
||||
)
|
||||
|
||||
assert result is False
|
||||
assert result is None
|
||||
|
||||
def test_returns_false_when_restore_path_outside_projects_base(self, tmp_path):
|
||||
"""Path traversal guard: rejects restoration outside the projects base."""
|
||||
def test_returns_transcript_download_on_success_no_meta(self):
|
||||
"""Happy path with no meta.json: returns TranscriptDownload with message_count=0."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import restore_cli_session
|
||||
from .transcript import download_transcript
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.retrieve.return_value = b'{"type":"assistant"}\n'
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
return_value=str(tmp_path),
|
||||
),
|
||||
# Return a path genuinely outside tmp_path so the boundary guard fires.
|
||||
patch(
|
||||
"backend.copilot.transcript._cli_session_path",
|
||||
return_value="/outside/escaped/session.jsonl",
|
||||
),
|
||||
):
|
||||
result = asyncio.run(
|
||||
restore_cli_session(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000000",
|
||||
sdk_cwd=str(tmp_path),
|
||||
)
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_returns_true_when_local_file_already_exists(self, tmp_path):
|
||||
"""Same-pod reuse: if local file exists, skip storage download and return True."""
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import restore_cli_session
|
||||
|
||||
session_id = "12345678-0000-0000-0000-000000000099"
|
||||
sdk_cwd = str(tmp_path)
|
||||
|
||||
# Pre-create the local session file (simulates previous turn on same pod)
|
||||
projects_base = os.path.realpath(str(tmp_path))
|
||||
encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", projects_base)
|
||||
session_dir = Path(projects_base) / encoded_cwd
|
||||
session_dir.mkdir(parents=True, exist_ok=True)
|
||||
existing_content = b'{"type":"user"}\n{"type":"assistant"}\n'
|
||||
(session_dir / f"{session_id}.jsonl").write_bytes(existing_content)
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
return_value=projects_base,
|
||||
),
|
||||
):
|
||||
result = asyncio.run(
|
||||
restore_cli_session(
|
||||
user_id="user-1",
|
||||
session_id=session_id,
|
||||
sdk_cwd=sdk_cwd,
|
||||
)
|
||||
)
|
||||
|
||||
assert result is True
|
||||
# Storage should NOT have been accessed (local file was used as-is)
|
||||
mock_storage.retrieve.assert_not_called()
|
||||
# Local file should be unchanged
|
||||
assert (session_dir / f"{session_id}.jsonl").read_bytes() == existing_content
|
||||
|
||||
def test_returns_true_on_success(self, tmp_path):
|
||||
"""Happy path: storage has the session → file written → returns True."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import restore_cli_session
|
||||
|
||||
projects_base = str(tmp_path)
|
||||
sdk_cwd = str(tmp_path)
|
||||
session_id = "12345678-0000-0000-0000-000000000003"
|
||||
content = b'{"type":"assistant"}\n'
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.retrieve.return_value = content
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
return_value=projects_base,
|
||||
),
|
||||
):
|
||||
result = asyncio.run(
|
||||
restore_cli_session(
|
||||
user_id="user-1",
|
||||
session_id=session_id,
|
||||
sdk_cwd=sdk_cwd,
|
||||
)
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_returns_false_on_download_exception(self):
|
||||
"""Non-FileNotFoundError during retrieve logs warning and returns False."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import restore_cli_session
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.retrieve.side_effect = RuntimeError("network error")
|
||||
mock_storage.retrieve.side_effect = [content, FileNotFoundError("no meta")]
|
||||
|
||||
with patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
@@ -1277,11 +1124,411 @@ class TestRestoreCliSession:
|
||||
return_value=mock_storage,
|
||||
):
|
||||
result = asyncio.run(
|
||||
restore_cli_session(
|
||||
download_transcript(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000004",
|
||||
sdk_cwd="/tmp/copilot-test",
|
||||
session_id=session_id,
|
||||
)
|
||||
)
|
||||
|
||||
assert result is False
|
||||
assert isinstance(result, TranscriptDownload)
|
||||
assert result.content == content
|
||||
assert result.message_count == 0
|
||||
assert result.mode == "sdk"
|
||||
|
||||
def test_returns_transcript_download_with_message_count_from_meta(self):
|
||||
"""When meta.json is present, message_count and mode are read from it."""
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import download_transcript
|
||||
|
||||
session_id = "12345678-0000-0000-0000-000000000005"
|
||||
content = b'{"type":"assistant"}\n'
|
||||
meta_bytes = json.dumps(
|
||||
{"message_count": 7, "mode": "sdk", "uploaded_at": 1234567.0}
|
||||
).encode()
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.retrieve.side_effect = [content, meta_bytes]
|
||||
|
||||
with patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
result = asyncio.run(
|
||||
download_transcript(
|
||||
user_id="user-1",
|
||||
session_id=session_id,
|
||||
)
|
||||
)
|
||||
|
||||
assert isinstance(result, TranscriptDownload)
|
||||
assert result.content == content
|
||||
assert result.message_count == 7
|
||||
assert result.mode == "sdk"
|
||||
|
||||
def test_returns_none_on_download_exception(self):
|
||||
"""Non-FileNotFoundError during retrieve logs warning and returns None."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import download_transcript
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.retrieve.side_effect = [
|
||||
RuntimeError("network error"),
|
||||
FileNotFoundError("no meta"),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
result = asyncio.run(
|
||||
download_transcript(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000004",
|
||||
)
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_baseline_mode_in_meta_returned(self):
|
||||
"""When meta.json contains mode='baseline', result.mode is 'baseline'."""
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import download_transcript
|
||||
|
||||
content = b'{"type":"assistant"}\n'
|
||||
meta_bytes = json.dumps(
|
||||
{"message_count": 3, "mode": "baseline", "uploaded_at": 0.0}
|
||||
).encode()
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.retrieve.side_effect = [content, meta_bytes]
|
||||
|
||||
with patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
result = asyncio.run(
|
||||
download_transcript(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000020",
|
||||
)
|
||||
)
|
||||
|
||||
assert isinstance(result, TranscriptDownload)
|
||||
assert result.mode == "baseline"
|
||||
assert result.message_count == 3
|
||||
|
||||
def test_invalid_mode_in_meta_defaults_to_sdk(self):
|
||||
"""Unknown mode value in meta.json falls back to 'sdk'."""
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import download_transcript
|
||||
|
||||
content = b'{"type":"assistant"}\n'
|
||||
meta_bytes = json.dumps({"message_count": 2, "mode": "unknown_mode"}).encode()
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.retrieve.side_effect = [content, meta_bytes]
|
||||
|
||||
with patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
result = asyncio.run(
|
||||
download_transcript(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000021",
|
||||
)
|
||||
)
|
||||
|
||||
assert isinstance(result, TranscriptDownload)
|
||||
assert result.mode == "sdk"
|
||||
|
||||
def test_invalid_utf8_meta_uses_defaults(self):
|
||||
"""Meta bytes that fail UTF-8 decode fall back to message_count=0, mode='sdk'."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import download_transcript
|
||||
|
||||
content = b'{"type":"assistant"}\n'
|
||||
bad_meta = b"\xff\xfe"
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.retrieve.side_effect = [content, bad_meta]
|
||||
|
||||
with patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
result = asyncio.run(
|
||||
download_transcript(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000022",
|
||||
)
|
||||
)
|
||||
|
||||
assert isinstance(result, TranscriptDownload)
|
||||
assert result.message_count == 0
|
||||
assert result.mode == "sdk"
|
||||
|
||||
def test_meta_fetch_exception_uses_defaults(self):
|
||||
"""Non-FileNotFoundError on meta fetch still returns content with defaults."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import download_transcript
|
||||
|
||||
content = b'{"type":"assistant"}\n'
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.retrieve.side_effect = [content, RuntimeError("meta unavailable")]
|
||||
|
||||
with patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
result = asyncio.run(
|
||||
download_transcript(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000023",
|
||||
)
|
||||
)
|
||||
|
||||
assert isinstance(result, TranscriptDownload)
|
||||
assert result.content == content
|
||||
assert result.message_count == 0
|
||||
assert result.mode == "sdk"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# detect_gap
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _msgs(*roles: str):
|
||||
"""Build a list of ChatMessage objects with the given roles."""
|
||||
from .model import ChatMessage
|
||||
|
||||
return [ChatMessage(role=r, content=f"{r}-{i}") for i, r in enumerate(roles)]
|
||||
|
||||
|
||||
class TestDetectGap:
|
||||
"""``detect_gap`` returns messages between transcript watermark and current turn."""
|
||||
|
||||
def _dl(self, message_count: int) -> TranscriptDownload:
|
||||
return TranscriptDownload(content=b"", message_count=message_count, mode="sdk")
|
||||
|
||||
def test_zero_watermark_returns_empty(self):
|
||||
"""message_count=0 means no watermark — skip gap detection."""
|
||||
dl = self._dl(0)
|
||||
messages = _msgs("user", "assistant", "user")
|
||||
assert detect_gap(dl, messages) == []
|
||||
|
||||
def test_watermark_covers_all_prefix_returns_empty(self):
|
||||
"""Transcript already covers all messages up to the current user turn."""
|
||||
# session: [user, assistant, user(current)] — wm=2 means covers up to assistant
|
||||
dl = self._dl(2)
|
||||
messages = _msgs("user", "assistant", "user")
|
||||
assert detect_gap(dl, messages) == []
|
||||
|
||||
def test_watermark_exceeds_session_returns_empty(self):
|
||||
"""Watermark ahead of session count (race / over-count) → no gap."""
|
||||
dl = self._dl(10)
|
||||
messages = _msgs("user", "assistant", "user")
|
||||
assert detect_gap(dl, messages) == []
|
||||
|
||||
def test_misaligned_watermark_not_on_assistant_returns_empty(self):
|
||||
"""Watermark at a user-role position is misaligned — skip gap."""
|
||||
# wm=1: position 0 is 'user', not 'assistant' → skip
|
||||
dl = self._dl(1)
|
||||
messages = _msgs("user", "assistant", "user", "assistant", "user")
|
||||
assert detect_gap(dl, messages) == []
|
||||
|
||||
def test_returns_gap_messages(self):
|
||||
"""Watermark behind session — gap messages returned (excluding current turn)."""
|
||||
# session: [user0, assistant1, user2, assistant3, user4(current)]
|
||||
# wm=2: transcript covers [0,1]; gap = [user2, assistant3]
|
||||
dl = self._dl(2)
|
||||
messages = _msgs("user", "assistant", "user", "assistant", "user")
|
||||
gap = detect_gap(dl, messages)
|
||||
assert len(gap) == 2
|
||||
assert gap[0].role == "user"
|
||||
assert gap[1].role == "assistant"
|
||||
|
||||
def test_excludes_current_user_turn(self):
|
||||
"""The last message (current user turn) is never included in the gap."""
|
||||
# wm=2, session has 4 msgs: gap = [msg2] only (msg3 is current turn → excluded)
|
||||
dl = self._dl(2)
|
||||
messages = _msgs("user", "assistant", "user", "user")
|
||||
gap = detect_gap(dl, messages)
|
||||
assert len(gap) == 1
|
||||
assert gap[0].role == "user"
|
||||
|
||||
def test_single_gap_message(self):
|
||||
"""One message between watermark and current turn."""
|
||||
# session: [user0, assistant1, assistant2, user3(current)]
# wm=2: position 1 is an assistant turn (aligned), so the gap is [assistant2];
# the current user turn at position 3 is excluded.
|
||||
dl = self._dl(2)
|
||||
messages = _msgs("user", "assistant", "assistant", "user")
|
||||
gap = detect_gap(dl, messages)
|
||||
assert len(gap) == 1
|
||||
assert gap[0].role == "assistant"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extract_context_messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_valid_transcript(*roles: str) -> str:
|
||||
"""Build a minimal valid JSONL transcript with the given message roles."""
|
||||
import json as stdlib_json
|
||||
|
||||
from .transcript import STOP_REASON_END_TURN
|
||||
|
||||
lines = []
|
||||
parent = ""
|
||||
for i, role in enumerate(roles):
|
||||
uid = f"uid-{i}"
|
||||
entry: dict = {
|
||||
"type": role,
|
||||
"uuid": uid,
|
||||
"parentUuid": parent,
|
||||
"message": {
|
||||
"role": role,
|
||||
"content": f"{role} content {i}",
|
||||
},
|
||||
}
|
||||
if role == "assistant":
|
||||
entry["message"]["id"] = f"msg_{i}"
|
||||
entry["message"]["model"] = "test-model"
|
||||
entry["message"]["type"] = "message"
|
||||
entry["message"]["stop_reason"] = STOP_REASON_END_TURN
|
||||
entry["message"]["content"] = [
|
||||
{"type": "text", "text": f"assistant content {i}"}
|
||||
]
|
||||
lines.append(stdlib_json.dumps(entry))
|
||||
parent = uid
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
class TestExtractContextMessages:
|
||||
"""``extract_context_messages`` returns the shared context primitive."""
|
||||
|
||||
def test_none_download_returns_prior(self):
|
||||
"""No download → falls back to all session messages except current turn."""
|
||||
messages = _msgs("user", "assistant", "user")
|
||||
result = extract_context_messages(None, messages)
|
||||
assert result == messages[:-1]
|
||||
assert len(result) == 2
|
||||
|
||||
def test_empty_content_download_returns_prior(self):
|
||||
"""Empty bytes content → falls back to all prior messages."""
|
||||
dl = TranscriptDownload(content=b"", message_count=2, mode="sdk")
|
||||
messages = _msgs("user", "assistant", "user")
|
||||
result = extract_context_messages(dl, messages)
|
||||
assert result == messages[:-1]
|
||||
|
||||
def test_valid_transcript_no_gap_returns_transcript_messages(self):
|
||||
"""Transcript covers all prior turns → only transcript messages returned."""
|
||||
# Transcript: [user, assistant] — 2 messages
|
||||
# Session: [user, assistant, user(current)] — watermark=2 covers prefix
|
||||
transcript_content = _make_valid_transcript("user", "assistant")
|
||||
dl = TranscriptDownload(
|
||||
content=transcript_content.encode("utf-8"), message_count=2, mode="sdk"
|
||||
)
|
||||
messages = _msgs("user", "assistant", "user")
|
||||
result = extract_context_messages(dl, messages)
|
||||
# Transcript has 2 messages (user + assistant) and no gap
|
||||
assert len(result) == 2
|
||||
assert result[0].role == "user"
|
||||
assert result[1].role == "assistant"
|
||||
|
||||
def test_valid_transcript_with_gap_returns_transcript_plus_gap(self):
|
||||
"""Transcript is stale → gap messages appended after transcript content."""
|
||||
# Transcript: [user, assistant] — watermark=2
|
||||
# Session: [user, assistant, user, assistant, user(current)]
|
||||
# Gap: [user(2), assistant(3)] — positions 2 and 3
|
||||
transcript_content = _make_valid_transcript("user", "assistant")
|
||||
dl = TranscriptDownload(
|
||||
content=transcript_content.encode("utf-8"), message_count=2, mode="sdk"
|
||||
)
|
||||
messages = _msgs("user", "assistant", "user", "assistant", "user")
|
||||
result = extract_context_messages(dl, messages)
|
||||
# 2 transcript messages + 2 gap messages = 4
|
||||
assert len(result) == 4
|
||||
assert result[0].role == "user" # transcript user
|
||||
assert result[1].role == "assistant" # transcript assistant
|
||||
assert result[2].role == "user" # gap user
|
||||
assert result[3].role == "assistant" # gap assistant
|
||||
|
||||
def test_compact_summary_entries_preserved(self):
|
||||
"""``isCompactSummary=True`` entries survive ``_transcript_to_messages``."""
|
||||
import json as stdlib_json
|
||||
|
||||
from .transcript import STOP_REASON_END_TURN
|
||||
|
||||
# Build a transcript where one entry is a compaction summary.
|
||||
# isCompactSummary=True entries have type in STRIPPABLE_TYPES but are kept.
|
||||
compact_entry = stdlib_json.dumps(
|
||||
{
|
||||
"type": "summary",
|
||||
"uuid": "uid-compact",
|
||||
"parentUuid": "",
|
||||
"isCompactSummary": True,
|
||||
"message": {
|
||||
"role": "user",
|
||||
"content": "COMPACT_SUMMARY_CONTENT",
|
||||
},
|
||||
}
|
||||
)
|
||||
assistant_entry = stdlib_json.dumps(
|
||||
{
|
||||
"type": "assistant",
|
||||
"uuid": "uid-1",
|
||||
"parentUuid": "uid-compact",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"id": "msg_1",
|
||||
"model": "test",
|
||||
"type": "message",
|
||||
"stop_reason": STOP_REASON_END_TURN,
|
||||
"content": [{"type": "text", "text": "response after compact"}],
|
||||
},
|
||||
}
|
||||
)
|
||||
content = compact_entry + "\n" + assistant_entry + "\n"
|
||||
dl = TranscriptDownload(
|
||||
content=content.encode("utf-8"), message_count=2, mode="sdk"
|
||||
)
|
||||
messages = _msgs("user", "assistant", "user")
|
||||
result = extract_context_messages(dl, messages)
|
||||
# Both the compact summary and the assistant response are present
|
||||
assert len(result) == 2
|
||||
roles = [m.role for m in result]
|
||||
assert "user" in roles # compact summary has role=user
|
||||
assert "assistant" in roles
|
||||
# The compact summary content is preserved
|
||||
compact_msgs = [m for m in result if m.role == "user"]
|
||||
assert any("COMPACT_SUMMARY_CONTENT" in (m.content or "") for m in compact_msgs)
|
||||
|
||||
@@ -88,17 +88,19 @@ async def cmd_download(session_ids: list[str]) -> None:
            print(f"[{sid[:12]}] Not found in GCS")
            continue

        content_str = (
            dl.content.decode("utf-8") if isinstance(dl.content, bytes) else dl.content
        )
        out = _transcript_path(sid)
        with open(out, "w") as f:
            f.write(dl.content)
            f.write(content_str)

        lines = len(dl.content.strip().split("\n"))
        lines = len(content_str.strip().split("\n"))
        meta = {
            "session_id": sid,
            "user_id": user_id,
            "message_count": dl.message_count,
            "uploaded_at": dl.uploaded_at,
            "transcript_bytes": len(dl.content),
            "transcript_bytes": len(content_str),
            "transcript_lines": lines,
        }
        with open(_meta_path(sid), "w") as f:
@@ -106,7 +108,7 @@ async def cmd_download(session_ids: list[str]) -> None:

        print(
            f"[{sid[:12]}] Saved: {lines} entries, "
            f"{len(dl.content)} bytes, msg_count={dl.message_count}"
            f"{len(content_str)} bytes, msg_count={dl.message_count}"
        )
    print("\nDone. Run 'load' command to import into local dev environment.")

@@ -227,7 +229,7 @@ async def cmd_load(session_ids: list[str]) -> None:
        await upload_transcript(
            user_id=user_id,
            session_id=sid,
            content=content,
            content=content.encode("utf-8"),
            message_count=msg_count,
        )
        print(f"[{sid[:12]}] Stored transcript in local workspace storage")

@@ -0,0 +1,140 @@
"""Unit tests for the transcript watermark (message_count) fix.

The bug: upload used message_count=len(session.messages) (DB count). When a
prior turn's GCS upload failed silently, the JSONL on GCS was stale (e.g.
covered only T1-T12) but the meta.json watermark matched the full DB count
(e.g. 46). The next turn's gap-fill check (transcript_msg_count < msg_count-1)
never triggered, so the model silently lost context for the skipped turns.

The fix: watermark = previous_coverage + 2 (current user+asst pair) when
use_resume=True and transcript_msg_count > 0. This ensures the watermark
reflects the JSONL content, not the DB count.

These tests exercise _build_query_message directly to verify that gap-fill
triggers with the corrected watermark but NOT with the inflated (buggy) one.
"""

from unittest.mock import MagicMock

import pytest

from backend.copilot.sdk.service import _build_query_message

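# Editor's illustrative sketch (not part of the change under review): a minimal model
# of the watermark rule and gap-fill check described in the module docstring, using
# hypothetical helper names; the real logic lives in the upload path and in
# _build_query_message.
def _sketch_upload_watermark(
    db_message_count: int, transcript_msg_count: int, use_resume: bool
) -> int:
    """Watermark should track what the JSONL actually covers, not the DB count."""
    if use_resume and transcript_msg_count > 0:
        # previous coverage + this turn's user/assistant pair
        return transcript_msg_count + 2
    return db_message_count


def _sketch_gap_fill_needed(watermark: int, db_message_count: int) -> bool:
    """Gap-fill should trigger when the watermark lags the session by more than one."""
    return watermark < db_message_count - 1


# With the fix: _sketch_upload_watermark(46, 24, True) == 26 and
# _sketch_gap_fill_needed(26, 47) is True, so the skipped turns are re-injected.
# With the buggy DB-count watermark, _sketch_gap_fill_needed(46, 47) is False.

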
def _make_messages(n_pairs: int, *, current_user: str = "current") -> list[MagicMock]:
    """Build a flat list of n_pairs*2 alternating user/asst messages, plus
    one trailing user message for the *current* turn."""
    msgs: list[MagicMock] = []
    for i in range(n_pairs):
        u = MagicMock()
        u.role = "user"
        u.content = f"user message {i}"
        a = MagicMock()
        a.role = "assistant"
        a.content = f"assistant response {i}"
        msgs.extend([u, a])
    # Current turn's user message
    cur = MagicMock()
    cur.role = "user"
    cur.content = current_user
    msgs.append(cur)
    return msgs


def _make_session(messages: list[MagicMock]) -> MagicMock:
    session = MagicMock()
    session.messages = messages
    return session


@pytest.mark.asyncio
async def test_gap_fill_triggers_for_stale_jsonl():
    """Scenario: T1-T12 in JSONL (watermark=24), DB has T1-T22+Test (46 msgs).

    With the FIX: 'Test' uploaded watermark=26 (T12's 24 + 2 for 'Test').
    Next turn (T24) downloads watermark=26, DB has 47.
    Gap check: 26 < 47-1=46 → TRUE → gap fills T14-T23.
    """
    # 23 turns in DB (46 messages) + T24 user = 47
    msgs = _make_messages(23, current_user="memory test - recall all")
    assert len(msgs) == 47

    session = _make_session(msgs)

    # Watermark as uploaded by the FIX: T12 covered 24, 'Test' +2 = 26
    result_msg, _ = await _build_query_message(
        current_message="memory test - recall all",
        session=session,
        use_resume=True,
        transcript_msg_count=26,
        session_id="test-session-id",
    )

    assert "<conversation_history>" in result_msg, (
        "Expected gap-fill to inject <conversation_history> when "
        "watermark=26 < msg_count-1=46"
    )


@pytest.mark.asyncio
async def test_no_gap_fill_when_watermark_is_current():
    """When the JSONL is fully current (watermark = DB-1), no gap injected."""
    # 23 turns in DB (46 messages) + T24 user = 47
    msgs = _make_messages(23, current_user="next message")
    session = _make_session(msgs)

    result_msg, _ = await _build_query_message(
        current_message="next message",
        session=session,
        use_resume=True,
        transcript_msg_count=46,  # current — no gap
        session_id="test-session-id",
    )

    assert (
        "<conversation_history>" not in result_msg
    ), "No gap-fill expected when watermark is current"
    assert result_msg == "next message"


@pytest.mark.asyncio
async def test_inflated_watermark_suppresses_gap_fill():
    """Documents the original bug: inflated watermark suppresses gap-fill.

    'Test' uploaded watermark=len(session.messages)=46 even though only 26
    messages are in the JSONL. Next turn: 46 < 47-1=46 → FALSE → no gap fill.
    """
    msgs = _make_messages(23, current_user="memory test")
    session = _make_session(msgs)

    # Buggy watermark: inflated to DB count
    result_msg, _ = await _build_query_message(
        current_message="memory test",
        session=session,
        use_resume=True,
        transcript_msg_count=46,  # inflated — suppresses gap fill
        session_id="test-session-id",
    )

    assert (
        "<conversation_history>" not in result_msg
    ), "With inflated watermark, gap-fill is suppressed — this documents the bug"


@pytest.mark.asyncio
async def test_fixed_watermark_fills_same_gap():
    """Same scenario, but with the FIXED watermark, gap-fill triggers."""
    msgs = _make_messages(23, current_user="memory test")
    session = _make_session(msgs)

    result_msg, _ = await _build_query_message(
        current_message="memory test",
        session=session,
        use_resume=True,
        transcript_msg_count=26,  # fixed watermark
        session_id="test-session-id",
    )

    assert (
        "<conversation_history>" in result_msg
    ), "With fixed watermark=26, gap-fill triggers and injects missing turns"