refactor(backend/copilot): unify transcript API — TranscriptDownload, TranscriptMode, detect_gap, baseline gap-fill

- Rename CliSessionRestore → TranscriptDownload; add mode: TranscriptMode field
- Add TranscriptMode = Literal["sdk", "baseline"] — persisted in .meta.json
- Rename upload_cli_session → upload_transcript (mode param)
- Rename restore_cli_session → download_transcript (reads mode from meta)
- Add detect_gap(download, session_messages) shared helper
- SDK: skip --resume when transcript mode != "sdk" (baseline-written JSONL)
- Baseline: fill gap via _append_gap_to_builder instead of discarding stale transcript
- Remove all backward-compat aliases; update all test files
This commit is contained in:
Zamil Majdy
2026-04-16 03:11:24 +07:00
parent 95a90b92df
commit d6d4fd5cba
9 changed files with 263 additions and 144 deletions

View File

@@ -66,10 +66,11 @@ from backend.copilot.tracking import track_user_message
from backend.copilot.transcript import (
STOP_REASON_END_TURN,
STOP_REASON_TOOL_USE,
CliSessionRestore,
restore_cli_session,
TranscriptDownload,
detect_gap,
download_transcript,
strip_for_upload,
upload_cli_session,
upload_transcript,
validate_transcript,
)
from backend.copilot.transcript_builder import TranscriptBuilder
@@ -713,21 +714,61 @@ def should_upload_transcript(
return bool(user_id) and transcript_covers_prefix
def _append_gap_to_builder(
    gap: list[ChatMessage],
    builder: TranscriptBuilder,
) -> None:
    """Append gap messages from chat-db into the TranscriptBuilder.

    Converts ChatMessage (OpenAI format) to TranscriptBuilder entries
    (Claude CLI JSONL format) so the uploaded transcript covers all turns.

    Args:
        gap: Chat-db messages that the stored transcript does not cover yet,
            in conversation order.
        builder: Transcript builder that receives one entry per convertible
            message. It is mutated in place.

    Messages with a role other than ``user``/``assistant``/``tool``, assistant
    messages that yield no content blocks, and tool messages without a
    ``tool_call_id`` are skipped.
    """
    import orjson

    for msg in gap:
        if msg.role == "user":
            builder.append_user(msg.content or "")
        elif msg.role == "assistant":
            content_blocks: list[dict] = []
            if msg.content:
                content_blocks.append({"type": "text", "text": msg.content})
            for tc in msg.tool_calls or []:
                # Malformed (non-dict) tool calls degrade to an empty call
                # rather than aborting the whole gap fill.
                call = tc if isinstance(tc, dict) else {}
                fn = call.get("function")
                if not isinstance(fn, dict):
                    # Covers both a missing key and an explicit ``None``.
                    fn = {}
                raw_args = fn.get("arguments") or "{}"
                if isinstance(raw_args, dict):
                    # Already-decoded arguments are used as-is.
                    input_data = raw_args
                else:
                    try:
                        input_data = orjson.loads(raw_args)
                    except (ValueError, TypeError):
                        # orjson.JSONDecodeError subclasses ValueError.
                        input_data = {}
                if not isinstance(input_data, dict):
                    # A tool_use ``input`` must be an object; valid JSON such
                    # as ``null`` or ``[]`` is not usable here.
                    input_data = {}
                content_blocks.append(
                    {
                        "type": "tool_use",
                        "id": call.get("id", ""),
                        "name": fn.get("name", "unknown"),
                        "input": input_data,
                    }
                )
            if content_blocks:
                builder.append_assistant(content_blocks=content_blocks)
        elif msg.role == "tool" and msg.tool_call_id:
            builder.append_tool_result(
                tool_use_id=msg.tool_call_id,
                content=msg.content or "",
            )
async def _load_prior_transcript(
user_id: str,
session_id: str,
session_msg_count: int,
session_messages: list[ChatMessage],
transcript_builder: TranscriptBuilder,
) -> bool:
"""Download and load the prior CLI session into ``transcript_builder``.
Returns ``True`` when the loaded session fully covers the session
prefix; ``False`` otherwise (stale, missing, invalid, or download
error). Callers should suppress uploads when this returns ``False``
to avoid overwriting a more complete version in storage.
prefix; ``False`` otherwise (missing, invalid, or download error).
Callers should suppress uploads when this returns ``False`` to avoid
overwriting a more complete version in storage.
"""
try:
restore = await restore_cli_session(
restore = await download_transcript(
user_id, session_id, log_prefix="[Baseline]"
)
except Exception as e:
@@ -749,20 +790,22 @@ async def _load_prior_transcript(
logger.warning("[Baseline] CLI session content invalid after strip")
return False
if restore.message_count > 0 and restore.message_count < session_msg_count - 1:
logger.warning(
"[Baseline] Session stale: covers %d of %d messages, skipping",
restore.message_count,
session_msg_count,
)
return False
transcript_builder.load_previous(stripped, log_prefix="[Baseline]")
logger.info(
"[Baseline] Loaded CLI session: %dB, msg_count=%d",
len(restore.content),
restore.message_count,
)
gap = detect_gap(restore, session_messages)
if gap:
_append_gap_to_builder(gap, transcript_builder)
logger.info(
"[Baseline] Filled gap: loaded %d transcript msgs + %d gap msgs from DB",
restore.message_count,
len(gap),
)
return True
@@ -794,11 +837,12 @@ async def _upload_final_transcript(
# orphaned coroutine; shield it so cancellation of this caller doesn't
# abort the in-flight GCS write.
upload_task = asyncio.create_task(
upload_cli_session(
upload_transcript(
user_id=user_id,
session_id=session_id,
content=content.encode("utf-8"),
message_count=session_msg_count,
mode="baseline",
log_prefix="[Baseline]",
)
)
@@ -911,7 +955,7 @@ async def stream_chat_completion_baseline(
_load_prior_transcript(
user_id=user_id,
session_id=session_id,
session_msg_count=len(session.messages),
session_messages=session.messages,
transcript_builder=transcript_builder,
),
prompt_task,

View File

@@ -18,11 +18,12 @@ from backend.copilot.baseline.service import (
_upload_final_transcript,
should_upload_transcript,
)
from backend.copilot.model import ChatMessage
from backend.copilot.service import config
from backend.copilot.transcript import (
STOP_REASON_END_TURN,
STOP_REASON_TOOL_USE,
CliSessionRestore,
TranscriptDownload,
)
from backend.copilot.transcript_builder import TranscriptBuilder
from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult
@@ -53,6 +54,11 @@ def _make_transcript_content(*roles: str) -> str:
return "\n".join(lines) + "\n"
def _make_session_messages(*roles: str) -> list[ChatMessage]:
    """Create one ChatMessage per role, with content ``"<role> message <index>"``."""
    messages = []
    for position, role in enumerate(roles):
        text = f"{role} message {position}"
        messages.append(ChatMessage(role=role, content=text))
    return messages
class TestResolveBaselineModel:
"""Model selection honours the per-request mode."""
@@ -78,16 +84,16 @@ class TestLoadPriorTranscript:
async def test_loads_fresh_transcript(self):
builder = TranscriptBuilder()
content = _make_transcript_content("user", "assistant")
restore = CliSessionRestore(content=content.encode("utf-8"), message_count=2)
restore = TranscriptDownload(content=content.encode("utf-8"), message_count=2, mode="sdk")
with patch(
"backend.copilot.baseline.service.restore_cli_session",
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
session_messages=_make_session_messages("user", "assistant", "user"),
transcript_builder=builder,
)
@@ -96,38 +102,39 @@ class TestLoadPriorTranscript:
assert builder.last_entry_type == "assistant"
@pytest.mark.asyncio
async def test_rejects_stale_transcript(self):
"""msg_count strictly less than session-1 is treated as stale."""
async def test_fills_gap_when_transcript_is_behind(self):
"""When transcript covers fewer messages than session, gap is filled from DB."""
builder = TranscriptBuilder()
content = _make_transcript_content("user", "assistant")
# session has 6 messages, transcript only covers 2 → stale.
restore = CliSessionRestore(content=content.encode("utf-8"), message_count=2)
# transcript covers 2 messages, session has 4 (plus current user turn = 5)
restore = TranscriptDownload(content=content.encode("utf-8"), message_count=2, mode="baseline")
with patch(
"backend.copilot.baseline.service.restore_cli_session",
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=6,
session_messages=_make_session_messages("user", "assistant", "user", "assistant", "user"),
transcript_builder=builder,
)
assert covers is False
assert builder.is_empty
assert covers is True
# 2 from transcript + 2 gap messages (user+assistant at positions 2,3)
assert builder.entry_count == 4
@pytest.mark.asyncio
async def test_missing_transcript_returns_false(self):
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.restore_cli_session",
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=None),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
session_messages=_make_session_messages("user", "assistant"),
transcript_builder=builder,
)
@@ -137,18 +144,19 @@ class TestLoadPriorTranscript:
@pytest.mark.asyncio
async def test_invalid_transcript_returns_false(self):
builder = TranscriptBuilder()
restore = CliSessionRestore(
restore = TranscriptDownload(
content=b'{"type":"progress","uuid":"a"}\n',
message_count=1,
mode="sdk",
)
with patch(
"backend.copilot.baseline.service.restore_cli_session",
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
session_messages=_make_session_messages("user", "assistant"),
transcript_builder=builder,
)
@@ -159,13 +167,13 @@ class TestLoadPriorTranscript:
async def test_download_exception_returns_false(self):
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.restore_cli_session",
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(side_effect=RuntimeError("boom")),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
session_messages=_make_session_messages("user", "assistant"),
transcript_builder=builder,
)
@@ -174,20 +182,21 @@ class TestLoadPriorTranscript:
@pytest.mark.asyncio
async def test_zero_message_count_not_stale(self):
"""When msg_count is 0 (unknown), staleness check is skipped."""
"""When msg_count is 0 (unknown), gap detection is skipped."""
builder = TranscriptBuilder()
restore = CliSessionRestore(
restore = TranscriptDownload(
content=_make_transcript_content("user", "assistant").encode("utf-8"),
message_count=0,
mode="sdk",
)
with patch(
"backend.copilot.baseline.service.restore_cli_session",
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=20,
session_messages=_make_session_messages(*["user"] * 20),
transcript_builder=builder,
)
@@ -210,7 +219,7 @@ class TestUploadFinalTranscript:
upload_mock = AsyncMock(return_value=None)
with patch(
"backend.copilot.baseline.service.upload_cli_session",
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
):
await _upload_final_transcript(
@@ -233,7 +242,7 @@ class TestUploadFinalTranscript:
builder = TranscriptBuilder()
upload_mock = AsyncMock(return_value=None)
with patch(
"backend.copilot.baseline.service.upload_cli_session",
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
):
await _upload_final_transcript(
@@ -257,7 +266,7 @@ class TestUploadFinalTranscript:
)
with patch(
"backend.copilot.baseline.service.upload_cli_session",
"backend.copilot.baseline.service.upload_transcript",
new=AsyncMock(side_effect=RuntimeError("storage unavailable")),
):
# Should not raise.
@@ -373,17 +382,17 @@ class TestRoundTrip:
@pytest.mark.asyncio
async def test_full_round_trip(self):
prior = _make_transcript_content("user", "assistant")
restore = CliSessionRestore(content=prior.encode("utf-8"), message_count=2)
restore = TranscriptDownload(content=prior.encode("utf-8"), message_count=2, mode="sdk")
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.restore_cli_session",
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
session_messages=_make_session_messages("user", "assistant", "user"),
transcript_builder=builder,
)
assert covers is True
@@ -410,7 +419,7 @@ class TestRoundTrip:
# Upload.
upload_mock = AsyncMock(return_value=None)
with patch(
"backend.copilot.baseline.service.upload_cli_session",
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
):
await _upload_final_transcript(
@@ -491,16 +500,16 @@ class TestTranscriptLifecycle:
"""Fresh restore, append a turn, upload covers the session."""
builder = TranscriptBuilder()
prior = _make_transcript_content("user", "assistant")
restore = CliSessionRestore(content=prior.encode("utf-8"), message_count=2)
restore = TranscriptDownload(content=prior.encode("utf-8"), message_count=2, mode="sdk")
upload_mock = AsyncMock(return_value=None)
with (
patch(
"backend.copilot.baseline.service.restore_cli_session",
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=restore),
),
patch(
"backend.copilot.baseline.service.upload_cli_session",
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
),
):
@@ -508,7 +517,7 @@ class TestTranscriptLifecycle:
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
session_messages=_make_session_messages("user", "assistant", "user"),
transcript_builder=builder,
)
assert covers is True
@@ -550,40 +559,39 @@ class TestTranscriptLifecycle:
assert b"assistant message 1" in uploaded
@pytest.mark.asyncio
async def test_lifecycle_stale_download_suppresses_upload(self):
"""Stale restore → covers=False → upload must be skipped."""
async def test_lifecycle_stale_download_fills_gap(self):
"""When transcript covers fewer messages, gap is filled rather than rejected."""
builder = TranscriptBuilder()
# session has 10 msgs but stored session only covers 2 → stale.
stale = CliSessionRestore(
# session has 5 msgs but stored transcript only covers 2 → gap filled.
stale = TranscriptDownload(
content=_make_transcript_content("user", "assistant").encode("utf-8"),
message_count=2,
mode="baseline",
)
upload_mock = AsyncMock(return_value=None)
with (
patch(
"backend.copilot.baseline.service.restore_cli_session",
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=stale),
),
patch(
"backend.copilot.baseline.service.upload_cli_session",
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=10,
session_messages=_make_session_messages(
"user", "assistant", "user", "assistant", "user"
),
transcript_builder=builder,
)
assert covers is False
# The caller's gate mirrors the production path.
assert (
should_upload_transcript(user_id="user-1", transcript_covers_prefix=covers)
is False
)
upload_mock.assert_not_awaited()
assert covers is True
# Gap was filled: 2 from transcript + 2 gap messages
assert builder.entry_count == 4
@pytest.mark.asyncio
async def test_lifecycle_anonymous_user_skips_upload(self):
@@ -609,18 +617,18 @@ class TestTranscriptLifecycle:
upload_mock = AsyncMock(return_value=None)
with (
patch(
"backend.copilot.baseline.service.restore_cli_session",
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=None),
),
patch(
"backend.copilot.baseline.service.upload_cli_session",
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=1,
session_messages=_make_session_messages("user"),
transcript_builder=builder,
)
# No restore: covers is False, so the production path would

View File

@@ -252,9 +252,10 @@ class TestSdkToFastModeSwitch:
async def test_scenario_s_baseline_loads_sdk_transcript(self):
"""Scenario S: SDK-written CLI session is accepted by baseline's load helper."""
from backend.copilot.baseline.service import _load_prior_transcript
from backend.copilot.model import ChatMessage
from backend.copilot.transcript import (
STOP_REASON_END_TURN,
CliSessionRestore,
TranscriptDownload,
)
from backend.copilot.transcript_builder import TranscriptBuilder
@@ -270,19 +271,23 @@ class TestSdkToFastModeSwitch:
sdk_transcript = builder_sdk.to_jsonl()
# Baseline session now has those 2 SDK messages + 1 new baseline message.
restore = CliSessionRestore(
content=sdk_transcript.encode("utf-8"), message_count=2
restore = TranscriptDownload(
content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk"
)
baseline_builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.restore_cli_session",
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3, # 2 SDK + 1 new baseline
session_messages=[
ChatMessage(role="user", content="sdk-question"),
ChatMessage(role="assistant", content="sdk-answer"),
ChatMessage(role="user", content="baseline-question"),
],
transcript_builder=baseline_builder,
)
@@ -299,9 +304,10 @@ class TestSdkToFastModeSwitch:
to avoid injecting an incomplete history.
"""
from backend.copilot.baseline.service import _load_prior_transcript
from backend.copilot.model import ChatMessage
from backend.copilot.transcript import (
STOP_REASON_END_TURN,
CliSessionRestore,
TranscriptDownload,
)
from backend.copilot.transcript_builder import TranscriptBuilder
@@ -315,22 +321,31 @@ class TestSdkToFastModeSwitch:
sdk_transcript = builder_sdk.to_jsonl()
# Transcript covers only 2 messages but the session has 10 (many SDK turns).
restore = CliSessionRestore(
content=sdk_transcript.encode("utf-8"), message_count=2
# With watermark=2 and 10 total messages, detect_gap will fill the gap
# by appending messages 2..8 (positions 2 to total-2).
restore = TranscriptDownload(
content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk"
)
# Build a session with 10 alternating user/assistant messages (indices 0..9); the final one stands in for the current turn
session_messages = [
ChatMessage(role="user" if i % 2 == 0 else "assistant", content=f"msg-{i}")
for i in range(10)
]
baseline_builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.restore_cli_session",
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=restore),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=10,
session_messages=session_messages,
transcript_builder=baseline_builder,
)
# Stale session must be rejected.
assert covers is False
assert baseline_builder.is_empty
# With gap filling, covers is True and gap messages are appended.
assert covers is True
# 2 from transcript + 7 gap messages (positions 2..8; the final message at position 9 is excluded as the current turn)
assert baseline_builder.entry_count == 9

View File

@@ -27,7 +27,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.transcript import (
CliSessionRestore,
TranscriptDownload,
_flatten_assistant_content,
_flatten_tool_result_content,
_messages_to_transcript,
@@ -997,16 +997,16 @@ def _make_sdk_patches(
dict(new_callable=AsyncMock, return_value=("system prompt", None)),
),
(
f"{_SVC}.restore_cli_session",
f"{_SVC}.download_transcript",
dict(
new_callable=AsyncMock,
return_value=CliSessionRestore(
content=original_transcript.encode("utf-8"), message_count=2
return_value=TranscriptDownload(
content=original_transcript.encode("utf-8"), message_count=2, mode="sdk"
),
),
),
(f"{_SVC}.strip_for_upload", dict(return_value=original_transcript)),
(f"{_SVC}.upload_cli_session", dict(new_callable=AsyncMock)),
(f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)),
(f"{_SVC}.validate_transcript", dict(return_value=True)),
(
f"{_SVC}.compact_transcript",
@@ -1913,14 +1913,14 @@ class TestStreamChatCompletionRetryIntegration:
compacted_transcript=None,
client_side_effect=_client_factory,
)
# Override restore_cli_session to return None (CLI native session unavailable)
# Override download_transcript to return None (CLI native session unavailable)
patches = [
(
(
f"{_SVC}.restore_cli_session",
f"{_SVC}.download_transcript",
dict(new_callable=AsyncMock, return_value=None),
)
if p[0] == f"{_SVC}.restore_cli_session"
if p[0] == f"{_SVC}.download_transcript"
else p
)
for p in patches
@@ -1943,7 +1943,7 @@ class TestStreamChatCompletionRetryIntegration:
# captured_options holds {"options": ClaudeAgentOptions}, so check
# the attribute directly rather than dict keys.
assert not getattr(captured_options.get("options"), "resume", None), (
f"--resume was set even though restore_cli_session returned False: "
f"--resume was set even though download_transcript returned None: "
f"{captured_options}"
)
assert any(isinstance(e, StreamStart) for e in events)

View File

@@ -93,15 +93,17 @@ from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
from ..transcript import (
_run_compression,
CliSessionRestore,
TranscriptDownload,
TranscriptMode,
cleanup_stale_project_dirs,
cli_session_path,
compact_transcript,
detect_gap,
download_transcript,
projects_base,
read_compacted_entries,
restore_cli_session,
strip_for_upload,
upload_cli_session,
upload_transcript,
validate_transcript,
)
from ..transcript_builder import TranscriptBuilder
@@ -947,7 +949,7 @@ def _read_cli_session_from_disk(
def _process_cli_restore(
cli_restore: CliSessionRestore,
cli_restore: TranscriptDownload,
sdk_cwd: str,
session_id: str,
log_prefix: str,
@@ -2600,7 +2602,7 @@ async def stream_chat_completion_sdk(
transcript_msg_count = 0
if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
try:
cli_restore = await restore_cli_session(
cli_restore = await download_transcript(
user_id, session_id, log_prefix=log_prefix
)
except Exception as restore_err:
@@ -2611,6 +2613,17 @@ async def stream_chat_completion_sdk(
)
cli_restore = None
# Only attempt --resume for SDK-written transcripts.
# Baseline-written transcripts use TranscriptBuilder format (synthetic IDs,
# stripped fields) that may not be valid for --resume.
if cli_restore is not None and cli_restore.mode != "sdk":
logger.info(
"%s Transcript written by mode=%r, skipping --resume — will reconstruct from DB",
log_prefix,
cli_restore.mode,
)
cli_restore = None
# Validate, strip, and write to disk — delegate to helper to reduce
# function complexity. Writing an invalid/corrupt file to disk then
# falling back to "no --resume" would cause the CLI to fail with
@@ -3529,11 +3542,12 @@ async def stream_chat_completion_sdk(
)
if _cli_content:
await asyncio.shield(
upload_cli_session(
upload_transcript(
user_id=user_id,
session_id=session_id,
content=_cli_content,
message_count=len(session.messages),
mode="sdk",
log_prefix=log_prefix,
)
)

View File

@@ -12,16 +12,18 @@ from backend.copilot.transcript import (
ENTRY_TYPE_MESSAGE,
STOP_REASON_END_TURN,
STRIPPABLE_TYPES,
CliSessionRestore,
TranscriptDownload,
TranscriptMode,
cleanup_stale_project_dirs,
compact_transcript,
delete_transcript,
detect_gap,
download_transcript,
read_compacted_entries,
restore_cli_session,
strip_for_upload,
strip_progress_entries,
strip_stale_thinking_blocks,
upload_cli_session,
upload_transcript,
validate_transcript,
write_transcript_to_tempfile,
)
@@ -31,16 +33,18 @@ __all__ = [
"ENTRY_TYPE_MESSAGE",
"STOP_REASON_END_TURN",
"STRIPPABLE_TYPES",
"CliSessionRestore",
"TranscriptDownload",
"TranscriptMode",
"cleanup_stale_project_dirs",
"compact_transcript",
"delete_transcript",
"detect_gap",
"download_transcript",
"read_compacted_entries",
"restore_cli_session",
"strip_for_upload",
"strip_progress_entries",
"strip_stale_thinking_blocks",
"upload_cli_session",
"upload_transcript",
"validate_transcript",
"write_transcript_to_tempfile",
]

View File

@@ -7,7 +7,7 @@ import pytest
from .model import create_chat_session, get_chat_session, upsert_chat_session
from .response_model import StreamError, StreamTextDelta
from .sdk import service as sdk_service
from .transcript import restore_cli_session
from .transcript import download_transcript
logger = logging.getLogger(__name__)
@@ -64,7 +64,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
cli_session = None
for _ in range(10):
await asyncio.sleep(0.5)
cli_session = await restore_cli_session(test_user_id, session.session_id)
cli_session = await download_transcript(test_user_id, session.session_id)
if cli_session:
break
if not cli_session:

View File

@@ -20,6 +20,7 @@ import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from uuid import uuid4
from backend.util import json
@@ -27,6 +28,9 @@ from backend.util.clients import get_openai_client
from backend.util.prompt import CompressResult, compress_context
from backend.util.workspace_storage import GCSWorkspaceStorage, get_workspace_storage
if TYPE_CHECKING:
from backend.copilot.model import ChatMessage
logger = logging.getLogger(__name__)
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
@@ -44,12 +48,14 @@ STRIPPABLE_TYPES = frozenset(
)
@dataclass
class CliSessionRestore:
"""Result of restoring the CLI native session file."""
TranscriptMode = Literal["sdk", "baseline"]
content: bytes # raw bytes written to disk (for builder seeding)
message_count: int = 0 # watermark from companion .meta.json
@dataclass
class TranscriptDownload:
    """Transcript content downloaded from storage plus its companion metadata."""

    # Raw transcript bytes as downloaded.
    content: bytes
    # Message-count watermark read from the companion .meta.json; 0 when unknown.
    message_count: int = 0
    mode: TranscriptMode = "sdk"  # "sdk" = Claude CLI native, "baseline" = TranscriptBuilder
# Storage prefix for the CLI's native session JSONL files (for cross-pod --resume).
@@ -661,11 +667,12 @@ def _cli_session_meta_path_parts(user_id: str, session_id: str) -> tuple[str, st
)
async def upload_cli_session(
async def upload_transcript(
user_id: str,
session_id: str,
content: bytes,
message_count: int = 0,
mode: TranscriptMode = "sdk",
log_prefix: str = "[Transcript]",
) -> None:
"""Upload CLI session content to GCS with companion meta.json.
@@ -674,7 +681,7 @@ async def upload_cli_session(
the session file from disk before calling this function.
Also uploads a companion .meta.json with the message_count watermark so
restore_cli_session can return it without a separate fetch.
download_transcript can return it without a separate fetch.
Called after each turn so the next turn can restore the file on any pod
(eliminating the pod-affinity requirement for --resume).
@@ -682,7 +689,7 @@ async def upload_cli_session(
storage = await get_workspace_storage()
wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id)
mwid, mfid, mfname = _cli_session_meta_path_parts(user_id, session_id)
meta = {"message_count": message_count, "uploaded_at": time.time()}
meta = {"message_count": message_count, "mode": mode, "uploaded_at": time.time()}
meta_encoded = json.dumps(meta).encode("utf-8")
session_result, meta_result = await asyncio.gather(
@@ -709,18 +716,18 @@ async def upload_cli_session(
)
async def restore_cli_session(
async def download_transcript(
user_id: str,
session_id: str,
log_prefix: str = "[Transcript]",
) -> CliSessionRestore | None:
"""Download CLI session from GCS. Returns content + message_count, or None if not found.
) -> TranscriptDownload | None:
"""Download CLI session from GCS. Returns content + message_count + mode, or None if not found.
Pure GCS operation — no disk I/O. The caller is responsible for writing
content to disk if --resume is needed.
Returns a CliSessionRestore with the raw content and message_count watermark
on success, or None if not available (first turn or upload failed).
Returns a TranscriptDownload with the raw content, message_count watermark,
and mode on success, or None if not available (first turn or upload failed).
"""
storage = await get_workspace_storage()
path = _build_path_from_parts(
@@ -747,15 +754,18 @@ async def restore_cli_session(
content: bytes = content_result
# Parse message_count from companion meta — best-effort, default to 0.
# Parse message_count and mode from companion meta — best-effort, defaults.
message_count = 0
mode: TranscriptMode = "sdk"
if isinstance(meta_result, FileNotFoundError):
pass # No meta — first upload or old version; default to 0
pass # No meta — old upload; default to "sdk"
elif isinstance(meta_result, BaseException):
logger.debug("%s Failed to load CLI session meta: %s", log_prefix, meta_result)
else:
meta = json.loads(meta_result.decode("utf-8"), fallback={})
message_count = meta.get("message_count", 0)
raw_mode = meta.get("mode", "sdk")
mode = raw_mode if raw_mode in ("sdk", "baseline") else "sdk"
logger.info(
"%s Downloaded CLI session (%dB, msg_count=%d)",
@@ -763,7 +773,29 @@ async def restore_cli_session(
len(content),
message_count,
)
return CliSessionRestore(content=content, message_count=message_count)
return TranscriptDownload(content=content, message_count=message_count, mode=mode)
def detect_gap(
    download: TranscriptDownload,
    session_messages: list[ChatMessage],
) -> list[ChatMessage]:
    """Return chat-db messages after the transcript watermark (excluding current user turn).

    Args:
        download: Downloaded transcript; only ``message_count`` (the
            watermark) is read.
        session_messages: All messages of the session in order; the last
            element is treated as the current turn and is never returned.

    Returns:
        A new list with the messages at positions ``watermark .. len - 2``.
        Returns ``[]`` if the transcript is current, the watermark is unknown
        (zero, negative, or not an integer), or the watermark position doesn't
        end on an assistant turn (misaligned watermark).
    """
    wm = download.message_count
    # The watermark comes from a best-effort meta file. A negative value would
    # index and slice from the END of the list and yield a bogus gap; a
    # non-int would raise on comparison. Treat both like "unknown" (0).
    if not isinstance(wm, int) or wm <= 0:
        return []
    total = len(session_messages)
    if wm >= total - 1:
        return []
    # Sanity: position wm-1 should be an assistant turn; misaligned watermark
    # means the DB messages shifted (e.g. deletion) — skip gap to avoid wrong context.
    if session_messages[wm - 1].role != "assistant":
        return []
    return list(session_messages[wm : total - 1])
async def delete_transcript(user_id: str, session_id: str) -> None:

View File

@@ -10,7 +10,7 @@ from unittest.mock import MagicMock
from backend.util import json
from .transcript import (
CliSessionRestore,
TranscriptDownload,
_build_path_from_parts,
_find_last_assistant_entry,
_flatten_assistant_content,
@@ -720,7 +720,7 @@ class TestUploadCliSession:
import asyncio
from unittest.mock import AsyncMock, patch
from .transcript import upload_cli_session
from .transcript import upload_transcript
mock_storage = AsyncMock()
content = b'{"type":"assistant"}\n'
@@ -731,7 +731,7 @@ class TestUploadCliSession:
return_value=mock_storage,
):
asyncio.run(
upload_cli_session(
upload_transcript(
user_id="user-1",
session_id="12345678-0000-0000-0000-000000000001",
content=content,
@@ -742,12 +742,12 @@ class TestUploadCliSession:
assert mock_storage.store.call_count == 2
def test_uploads_companion_meta_json_with_message_count(self):
"""upload_cli_session stores a companion .meta.json with message_count."""
"""upload_transcript stores a companion .meta.json with message_count."""
import asyncio
import json
from unittest.mock import AsyncMock, patch
from .transcript import upload_cli_session
from .transcript import upload_transcript
mock_storage = AsyncMock()
content = b'{"type":"assistant"}\n'
@@ -758,7 +758,7 @@ class TestUploadCliSession:
return_value=mock_storage,
):
asyncio.run(
upload_cli_session(
upload_transcript(
user_id="user-1",
session_id="12345678-0000-0000-0000-000000000010",
content=content,
@@ -781,7 +781,7 @@ class TestUploadCliSession:
import asyncio
from unittest.mock import AsyncMock, patch
from .transcript import upload_cli_session
from .transcript import upload_transcript
mock_storage = AsyncMock()
mock_storage.store.side_effect = [RuntimeError("gcs unavailable"), None]
@@ -794,7 +794,7 @@ class TestUploadCliSession:
):
# Should not raise — failures are logged as warnings
asyncio.run(
upload_cli_session(
upload_transcript(
user_id="user-1",
session_id="12345678-0000-0000-0000-000000000002",
content=content,
@@ -810,7 +810,7 @@ class TestRestoreCliSession:
import asyncio
from unittest.mock import AsyncMock, patch
from .transcript import restore_cli_session
from .transcript import download_transcript
mock_storage = AsyncMock()
mock_storage.retrieve.side_effect = [
@@ -824,7 +824,7 @@ class TestRestoreCliSession:
return_value=mock_storage,
):
result = asyncio.run(
restore_cli_session(
download_transcript(
user_id="user-1",
session_id="12345678-0000-0000-0000-000000000000",
)
@@ -832,12 +832,12 @@ class TestRestoreCliSession:
assert result is None
def test_returns_cli_session_restore_on_success_no_meta(self):
"""Happy path with no meta.json: returns CliSessionRestore with message_count=0."""
def test_returns_transcript_download_on_success_no_meta(self):
"""Happy path with no meta.json: returns TranscriptDownload with message_count=0."""
import asyncio
from unittest.mock import AsyncMock, patch
from .transcript import restore_cli_session
from .transcript import download_transcript
session_id = "12345678-0000-0000-0000-000000000003"
content = b'{"type":"assistant"}\n'
@@ -851,27 +851,28 @@ class TestRestoreCliSession:
return_value=mock_storage,
):
result = asyncio.run(
restore_cli_session(
download_transcript(
user_id="user-1",
session_id=session_id,
)
)
assert isinstance(result, CliSessionRestore)
assert isinstance(result, TranscriptDownload)
assert result.content == content
assert result.message_count == 0
assert result.mode == "sdk"
def test_returns_cli_session_restore_with_message_count_from_meta(self):
"""When meta.json is present, message_count is read from it."""
def test_returns_transcript_download_with_message_count_from_meta(self):
"""When meta.json is present, message_count and mode are read from it."""
import asyncio
import json
from unittest.mock import AsyncMock, patch
from .transcript import restore_cli_session
from .transcript import download_transcript
session_id = "12345678-0000-0000-0000-000000000005"
content = b'{"type":"assistant"}\n'
meta_bytes = json.dumps({"message_count": 7, "uploaded_at": 1234567.0}).encode()
meta_bytes = json.dumps({"message_count": 7, "mode": "sdk", "uploaded_at": 1234567.0}).encode()
mock_storage = AsyncMock()
mock_storage.retrieve.side_effect = [content, meta_bytes]
@@ -882,22 +883,23 @@ class TestRestoreCliSession:
return_value=mock_storage,
):
result = asyncio.run(
restore_cli_session(
download_transcript(
user_id="user-1",
session_id=session_id,
)
)
assert isinstance(result, CliSessionRestore)
assert isinstance(result, TranscriptDownload)
assert result.content == content
assert result.message_count == 7
assert result.mode == "sdk"
def test_returns_none_on_download_exception(self):
"""Non-FileNotFoundError during retrieve logs warning and returns None."""
import asyncio
from unittest.mock import AsyncMock, patch
from .transcript import restore_cli_session
from .transcript import download_transcript
mock_storage = AsyncMock()
mock_storage.retrieve.side_effect = [
@@ -911,7 +913,7 @@ class TestRestoreCliSession:
return_value=mock_storage,
):
result = asyncio.run(
restore_cli_session(
download_transcript(
user_id="user-1",
session_id="12345678-0000-0000-0000-000000000004",
)