mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
refactor(backend/copilot): unify transcript API — TranscriptDownload, TranscriptMode, detect_gap, baseline gap-fill
- Rename CliSessionRestore → TranscriptDownload; add mode: TranscriptMode field
- Add TranscriptMode = Literal["sdk", "baseline"] — persisted in .meta.json
- Rename upload_cli_session → upload_transcript (mode param)
- Rename restore_cli_session → download_transcript (reads mode from meta)
- Add detect_gap(download, session_messages) shared helper
- SDK: skip --resume when transcript mode != "sdk" (baseline-written JSONL)
- Baseline: fill gap via _append_gap_to_builder instead of discarding stale transcript
- Remove all backward-compat aliases; update all test files
This commit is contained in:
@@ -66,10 +66,11 @@ from backend.copilot.tracking import track_user_message
|
||||
from backend.copilot.transcript import (
|
||||
STOP_REASON_END_TURN,
|
||||
STOP_REASON_TOOL_USE,
|
||||
CliSessionRestore,
|
||||
restore_cli_session,
|
||||
TranscriptDownload,
|
||||
detect_gap,
|
||||
download_transcript,
|
||||
strip_for_upload,
|
||||
upload_cli_session,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
@@ -713,21 +714,61 @@ def should_upload_transcript(
|
||||
return bool(user_id) and transcript_covers_prefix
|
||||
|
||||
|
||||
def _append_gap_to_builder(
|
||||
gap: list[ChatMessage],
|
||||
builder: TranscriptBuilder,
|
||||
) -> None:
|
||||
"""Append gap messages from chat-db into the TranscriptBuilder.
|
||||
|
||||
Converts ChatMessage (OpenAI format) to TranscriptBuilder entries
|
||||
(Claude CLI JSONL format) so the uploaded transcript covers all turns.
|
||||
"""
|
||||
import orjson
|
||||
|
||||
for msg in gap:
|
||||
if msg.role == "user":
|
||||
builder.append_user(msg.content or "")
|
||||
elif msg.role == "assistant":
|
||||
content_blocks: list[dict] = []
|
||||
if msg.content:
|
||||
content_blocks.append({"type": "text", "text": msg.content})
|
||||
if msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
fn = tc.get("function", {}) if isinstance(tc, dict) else {}
|
||||
try:
|
||||
input_data = orjson.loads(fn.get("arguments", "{}"))
|
||||
except Exception:
|
||||
input_data = {}
|
||||
content_blocks.append({
|
||||
"type": "tool_use",
|
||||
"id": tc.get("id", "") if isinstance(tc, dict) else "",
|
||||
"name": fn.get("name", "unknown"),
|
||||
"input": input_data,
|
||||
})
|
||||
if content_blocks:
|
||||
builder.append_assistant(content_blocks=content_blocks)
|
||||
elif msg.role == "tool" and msg.tool_call_id:
|
||||
builder.append_tool_result(
|
||||
tool_use_id=msg.tool_call_id,
|
||||
content=msg.content or "",
|
||||
)
|
||||
|
||||
|
||||
async def _load_prior_transcript(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
session_msg_count: int,
|
||||
session_messages: list[ChatMessage],
|
||||
transcript_builder: TranscriptBuilder,
|
||||
) -> bool:
|
||||
"""Download and load the prior CLI session into ``transcript_builder``.
|
||||
|
||||
Returns ``True`` when the loaded session fully covers the session
|
||||
prefix; ``False`` otherwise (stale, missing, invalid, or download
|
||||
error). Callers should suppress uploads when this returns ``False``
|
||||
to avoid overwriting a more complete version in storage.
|
||||
prefix; ``False`` otherwise (missing, invalid, or download error).
|
||||
Callers should suppress uploads when this returns ``False`` to avoid
|
||||
overwriting a more complete version in storage.
|
||||
"""
|
||||
try:
|
||||
restore = await restore_cli_session(
|
||||
restore = await download_transcript(
|
||||
user_id, session_id, log_prefix="[Baseline]"
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -749,20 +790,22 @@ async def _load_prior_transcript(
|
||||
logger.warning("[Baseline] CLI session content invalid after strip")
|
||||
return False
|
||||
|
||||
if restore.message_count > 0 and restore.message_count < session_msg_count - 1:
|
||||
logger.warning(
|
||||
"[Baseline] Session stale: covers %d of %d messages, skipping",
|
||||
restore.message_count,
|
||||
session_msg_count,
|
||||
)
|
||||
return False
|
||||
|
||||
transcript_builder.load_previous(stripped, log_prefix="[Baseline]")
|
||||
logger.info(
|
||||
"[Baseline] Loaded CLI session: %dB, msg_count=%d",
|
||||
len(restore.content),
|
||||
restore.message_count,
|
||||
)
|
||||
|
||||
gap = detect_gap(restore, session_messages)
|
||||
if gap:
|
||||
_append_gap_to_builder(gap, transcript_builder)
|
||||
logger.info(
|
||||
"[Baseline] Filled gap: loaded %d transcript msgs + %d gap msgs from DB",
|
||||
restore.message_count,
|
||||
len(gap),
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@@ -794,11 +837,12 @@ async def _upload_final_transcript(
|
||||
# orphaned coroutine; shield it so cancellation of this caller doesn't
|
||||
# abort the in-flight GCS write.
|
||||
upload_task = asyncio.create_task(
|
||||
upload_cli_session(
|
||||
upload_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
content=content.encode("utf-8"),
|
||||
message_count=session_msg_count,
|
||||
mode="baseline",
|
||||
log_prefix="[Baseline]",
|
||||
)
|
||||
)
|
||||
@@ -911,7 +955,7 @@ async def stream_chat_completion_baseline(
|
||||
_load_prior_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
session_msg_count=len(session.messages),
|
||||
session_messages=session.messages,
|
||||
transcript_builder=transcript_builder,
|
||||
),
|
||||
prompt_task,
|
||||
|
||||
@@ -18,11 +18,12 @@ from backend.copilot.baseline.service import (
|
||||
_upload_final_transcript,
|
||||
should_upload_transcript,
|
||||
)
|
||||
from backend.copilot.model import ChatMessage
|
||||
from backend.copilot.service import config
|
||||
from backend.copilot.transcript import (
|
||||
STOP_REASON_END_TURN,
|
||||
STOP_REASON_TOOL_USE,
|
||||
CliSessionRestore,
|
||||
TranscriptDownload,
|
||||
)
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult
|
||||
@@ -53,6 +54,11 @@ def _make_transcript_content(*roles: str) -> str:
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
def _make_session_messages(*roles: str) -> list[ChatMessage]:
    """Create one ChatMessage per role, preserving the order given.

    Each message's content is ``"<role> message <index>"``, where ``<index>``
    is the zero-based position of that role in ``roles``.
    """
    messages: list[ChatMessage] = []
    for position, role in enumerate(roles):
        messages.append(
            ChatMessage(role=role, content=f"{role} message {position}")
        )
    return messages
|
||||
|
||||
|
||||
class TestResolveBaselineModel:
|
||||
"""Model selection honours the per-request mode."""
|
||||
|
||||
@@ -78,16 +84,16 @@ class TestLoadPriorTranscript:
|
||||
async def test_loads_fresh_transcript(self):
|
||||
builder = TranscriptBuilder()
|
||||
content = _make_transcript_content("user", "assistant")
|
||||
restore = CliSessionRestore(content=content.encode("utf-8"), message_count=2)
|
||||
restore = TranscriptDownload(content=content.encode("utf-8"), message_count=2, mode="sdk")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.restore_cli_session",
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=3,
|
||||
session_messages=_make_session_messages("user", "assistant", "user"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
@@ -96,38 +102,39 @@ class TestLoadPriorTranscript:
|
||||
assert builder.last_entry_type == "assistant"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_stale_transcript(self):
|
||||
"""msg_count strictly less than session-1 is treated as stale."""
|
||||
async def test_fills_gap_when_transcript_is_behind(self):
|
||||
"""When transcript covers fewer messages than session, gap is filled from DB."""
|
||||
builder = TranscriptBuilder()
|
||||
content = _make_transcript_content("user", "assistant")
|
||||
# session has 6 messages, transcript only covers 2 → stale.
|
||||
restore = CliSessionRestore(content=content.encode("utf-8"), message_count=2)
|
||||
# transcript covers 2 messages, session has 4 (plus current user turn = 5)
|
||||
restore = TranscriptDownload(content=content.encode("utf-8"), message_count=2, mode="baseline")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.restore_cli_session",
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=6,
|
||||
session_messages=_make_session_messages("user", "assistant", "user", "assistant", "user"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
assert builder.is_empty
|
||||
assert covers is True
|
||||
# 2 from transcript + 2 gap messages (user+assistant at positions 2,3)
|
||||
assert builder.entry_count == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_transcript_returns_false(self):
|
||||
builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.restore_cli_session",
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=None),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=2,
|
||||
session_messages=_make_session_messages("user", "assistant"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
@@ -137,18 +144,19 @@ class TestLoadPriorTranscript:
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_transcript_returns_false(self):
|
||||
builder = TranscriptBuilder()
|
||||
restore = CliSessionRestore(
|
||||
restore = TranscriptDownload(
|
||||
content=b'{"type":"progress","uuid":"a"}\n',
|
||||
message_count=1,
|
||||
mode="sdk",
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.restore_cli_session",
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=2,
|
||||
session_messages=_make_session_messages("user", "assistant"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
@@ -159,13 +167,13 @@ class TestLoadPriorTranscript:
|
||||
async def test_download_exception_returns_false(self):
|
||||
builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.restore_cli_session",
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(side_effect=RuntimeError("boom")),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=2,
|
||||
session_messages=_make_session_messages("user", "assistant"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
@@ -174,20 +182,21 @@ class TestLoadPriorTranscript:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_message_count_not_stale(self):
|
||||
"""When msg_count is 0 (unknown), staleness check is skipped."""
|
||||
"""When msg_count is 0 (unknown), gap detection is skipped."""
|
||||
builder = TranscriptBuilder()
|
||||
restore = CliSessionRestore(
|
||||
restore = TranscriptDownload(
|
||||
content=_make_transcript_content("user", "assistant").encode("utf-8"),
|
||||
message_count=0,
|
||||
mode="sdk",
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.restore_cli_session",
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=20,
|
||||
session_messages=_make_session_messages(*["user"] * 20),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
@@ -210,7 +219,7 @@ class TestUploadFinalTranscript:
|
||||
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.upload_cli_session",
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
):
|
||||
await _upload_final_transcript(
|
||||
@@ -233,7 +242,7 @@ class TestUploadFinalTranscript:
|
||||
builder = TranscriptBuilder()
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.upload_cli_session",
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
):
|
||||
await _upload_final_transcript(
|
||||
@@ -257,7 +266,7 @@ class TestUploadFinalTranscript:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.upload_cli_session",
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=AsyncMock(side_effect=RuntimeError("storage unavailable")),
|
||||
):
|
||||
# Should not raise.
|
||||
@@ -373,17 +382,17 @@ class TestRoundTrip:
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_round_trip(self):
|
||||
prior = _make_transcript_content("user", "assistant")
|
||||
restore = CliSessionRestore(content=prior.encode("utf-8"), message_count=2)
|
||||
restore = TranscriptDownload(content=prior.encode("utf-8"), message_count=2, mode="sdk")
|
||||
|
||||
builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.restore_cli_session",
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=3,
|
||||
session_messages=_make_session_messages("user", "assistant", "user"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
assert covers is True
|
||||
@@ -410,7 +419,7 @@ class TestRoundTrip:
|
||||
# Upload.
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.upload_cli_session",
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
):
|
||||
await _upload_final_transcript(
|
||||
@@ -491,16 +500,16 @@ class TestTranscriptLifecycle:
|
||||
"""Fresh restore, append a turn, upload covers the session."""
|
||||
builder = TranscriptBuilder()
|
||||
prior = _make_transcript_content("user", "assistant")
|
||||
restore = CliSessionRestore(content=prior.encode("utf-8"), message_count=2)
|
||||
restore = TranscriptDownload(content=prior.encode("utf-8"), message_count=2, mode="sdk")
|
||||
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.baseline.service.restore_cli_session",
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=restore),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.baseline.service.upload_cli_session",
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
),
|
||||
):
|
||||
@@ -508,7 +517,7 @@ class TestTranscriptLifecycle:
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=3,
|
||||
session_messages=_make_session_messages("user", "assistant", "user"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
assert covers is True
|
||||
@@ -550,40 +559,39 @@ class TestTranscriptLifecycle:
|
||||
assert b"assistant message 1" in uploaded
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_stale_download_suppresses_upload(self):
|
||||
"""Stale restore → covers=False → upload must be skipped."""
|
||||
async def test_lifecycle_stale_download_fills_gap(self):
|
||||
"""When transcript covers fewer messages, gap is filled rather than rejected."""
|
||||
builder = TranscriptBuilder()
|
||||
# session has 10 msgs but stored session only covers 2 → stale.
|
||||
stale = CliSessionRestore(
|
||||
# session has 5 msgs but stored transcript only covers 2 → gap filled.
|
||||
stale = TranscriptDownload(
|
||||
content=_make_transcript_content("user", "assistant").encode("utf-8"),
|
||||
message_count=2,
|
||||
mode="baseline",
|
||||
)
|
||||
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.baseline.service.restore_cli_session",
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=stale),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.baseline.service.upload_cli_session",
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=10,
|
||||
session_messages=_make_session_messages(
|
||||
"user", "assistant", "user", "assistant", "user"
|
||||
),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
# The caller's gate mirrors the production path.
|
||||
assert (
|
||||
should_upload_transcript(user_id="user-1", transcript_covers_prefix=covers)
|
||||
is False
|
||||
)
|
||||
upload_mock.assert_not_awaited()
|
||||
assert covers is True
|
||||
# Gap was filled: 2 from transcript + 2 gap messages
|
||||
assert builder.entry_count == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_anonymous_user_skips_upload(self):
|
||||
@@ -609,18 +617,18 @@ class TestTranscriptLifecycle:
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.baseline.service.restore_cli_session",
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=None),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.baseline.service.upload_cli_session",
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=1,
|
||||
session_messages=_make_session_messages("user"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
# No restore: covers is False, so the production path would
|
||||
|
||||
@@ -252,9 +252,10 @@ class TestSdkToFastModeSwitch:
|
||||
async def test_scenario_s_baseline_loads_sdk_transcript(self):
|
||||
"""Scenario S: SDK-written CLI session is accepted by baseline's load helper."""
|
||||
from backend.copilot.baseline.service import _load_prior_transcript
|
||||
from backend.copilot.model import ChatMessage
|
||||
from backend.copilot.transcript import (
|
||||
STOP_REASON_END_TURN,
|
||||
CliSessionRestore,
|
||||
TranscriptDownload,
|
||||
)
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
|
||||
@@ -270,19 +271,23 @@ class TestSdkToFastModeSwitch:
|
||||
sdk_transcript = builder_sdk.to_jsonl()
|
||||
|
||||
# Baseline session now has those 2 SDK messages + 1 new baseline message.
|
||||
restore = CliSessionRestore(
|
||||
content=sdk_transcript.encode("utf-8"), message_count=2
|
||||
restore = TranscriptDownload(
|
||||
content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk"
|
||||
)
|
||||
|
||||
baseline_builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.restore_cli_session",
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=3, # 2 SDK + 1 new baseline
|
||||
session_messages=[
|
||||
ChatMessage(role="user", content="sdk-question"),
|
||||
ChatMessage(role="assistant", content="sdk-answer"),
|
||||
ChatMessage(role="user", content="baseline-question"),
|
||||
],
|
||||
transcript_builder=baseline_builder,
|
||||
)
|
||||
|
||||
@@ -299,9 +304,10 @@ class TestSdkToFastModeSwitch:
|
||||
to avoid injecting an incomplete history.
|
||||
"""
|
||||
from backend.copilot.baseline.service import _load_prior_transcript
|
||||
from backend.copilot.model import ChatMessage
|
||||
from backend.copilot.transcript import (
|
||||
STOP_REASON_END_TURN,
|
||||
CliSessionRestore,
|
||||
TranscriptDownload,
|
||||
)
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
|
||||
@@ -315,22 +321,31 @@ class TestSdkToFastModeSwitch:
|
||||
sdk_transcript = builder_sdk.to_jsonl()
|
||||
|
||||
# Session covers only 2 messages but session has 10 (many SDK turns).
|
||||
restore = CliSessionRestore(
|
||||
content=sdk_transcript.encode("utf-8"), message_count=2
|
||||
# With watermark=2 and 10 total messages, detect_gap will fill the gap
|
||||
# by appending messages 2..8 (positions 2 to total-2).
|
||||
restore = TranscriptDownload(
|
||||
content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk"
|
||||
)
|
||||
|
||||
# Build a session with 10 alternating user/assistant messages + current user
|
||||
session_messages = [
|
||||
ChatMessage(role="user" if i % 2 == 0 else "assistant", content=f"msg-{i}")
|
||||
for i in range(10)
|
||||
]
|
||||
|
||||
baseline_builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.restore_cli_session",
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=10,
|
||||
session_messages=session_messages,
|
||||
transcript_builder=baseline_builder,
|
||||
)
|
||||
|
||||
# Stale session must be rejected.
|
||||
assert covers is False
|
||||
assert baseline_builder.is_empty
|
||||
# With gap filling, covers is True and gap messages are appended.
|
||||
assert covers is True
|
||||
# 2 from transcript + 7 gap messages (positions 2..8, excluding last user turn)
|
||||
assert baseline_builder.entry_count == 9
|
||||
|
||||
@@ -27,7 +27,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from backend.copilot.transcript import (
|
||||
CliSessionRestore,
|
||||
TranscriptDownload,
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
@@ -997,16 +997,16 @@ def _make_sdk_patches(
|
||||
dict(new_callable=AsyncMock, return_value=("system prompt", None)),
|
||||
),
|
||||
(
|
||||
f"{_SVC}.restore_cli_session",
|
||||
f"{_SVC}.download_transcript",
|
||||
dict(
|
||||
new_callable=AsyncMock,
|
||||
return_value=CliSessionRestore(
|
||||
content=original_transcript.encode("utf-8"), message_count=2
|
||||
return_value=TranscriptDownload(
|
||||
content=original_transcript.encode("utf-8"), message_count=2, mode="sdk"
|
||||
),
|
||||
),
|
||||
),
|
||||
(f"{_SVC}.strip_for_upload", dict(return_value=original_transcript)),
|
||||
(f"{_SVC}.upload_cli_session", dict(new_callable=AsyncMock)),
|
||||
(f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)),
|
||||
(f"{_SVC}.validate_transcript", dict(return_value=True)),
|
||||
(
|
||||
f"{_SVC}.compact_transcript",
|
||||
@@ -1913,14 +1913,14 @@ class TestStreamChatCompletionRetryIntegration:
|
||||
compacted_transcript=None,
|
||||
client_side_effect=_client_factory,
|
||||
)
|
||||
# Override restore_cli_session to return None (CLI native session unavailable)
|
||||
# Override download_transcript to return None (CLI native session unavailable)
|
||||
patches = [
|
||||
(
|
||||
(
|
||||
f"{_SVC}.restore_cli_session",
|
||||
f"{_SVC}.download_transcript",
|
||||
dict(new_callable=AsyncMock, return_value=None),
|
||||
)
|
||||
if p[0] == f"{_SVC}.restore_cli_session"
|
||||
if p[0] == f"{_SVC}.download_transcript"
|
||||
else p
|
||||
)
|
||||
for p in patches
|
||||
@@ -1943,7 +1943,7 @@ class TestStreamChatCompletionRetryIntegration:
|
||||
# captured_options holds {"options": ClaudeAgentOptions}, so check
|
||||
# the attribute directly rather than dict keys.
|
||||
assert not getattr(captured_options.get("options"), "resume", None), (
|
||||
f"--resume was set even though restore_cli_session returned False: "
|
||||
f"--resume was set even though download_transcript returned None: "
|
||||
f"{captured_options}"
|
||||
)
|
||||
assert any(isinstance(e, StreamStart) for e in events)
|
||||
|
||||
@@ -93,15 +93,17 @@ from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
||||
from ..tracking import track_user_message
|
||||
from ..transcript import (
|
||||
_run_compression,
|
||||
CliSessionRestore,
|
||||
TranscriptDownload,
|
||||
TranscriptMode,
|
||||
cleanup_stale_project_dirs,
|
||||
cli_session_path,
|
||||
compact_transcript,
|
||||
detect_gap,
|
||||
download_transcript,
|
||||
projects_base,
|
||||
read_compacted_entries,
|
||||
restore_cli_session,
|
||||
strip_for_upload,
|
||||
upload_cli_session,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
from ..transcript_builder import TranscriptBuilder
|
||||
@@ -947,7 +949,7 @@ def _read_cli_session_from_disk(
|
||||
|
||||
|
||||
def _process_cli_restore(
|
||||
cli_restore: CliSessionRestore,
|
||||
cli_restore: TranscriptDownload,
|
||||
sdk_cwd: str,
|
||||
session_id: str,
|
||||
log_prefix: str,
|
||||
@@ -2600,7 +2602,7 @@ async def stream_chat_completion_sdk(
|
||||
transcript_msg_count = 0
|
||||
if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
|
||||
try:
|
||||
cli_restore = await restore_cli_session(
|
||||
cli_restore = await download_transcript(
|
||||
user_id, session_id, log_prefix=log_prefix
|
||||
)
|
||||
except Exception as restore_err:
|
||||
@@ -2611,6 +2613,17 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
cli_restore = None
|
||||
|
||||
# Only attempt --resume for SDK-written transcripts.
|
||||
# Baseline-written transcripts use TranscriptBuilder format (synthetic IDs,
|
||||
# stripped fields) that may not be valid for --resume.
|
||||
if cli_restore is not None and cli_restore.mode != "sdk":
|
||||
logger.info(
|
||||
"%s Transcript written by mode=%r, skipping --resume — will reconstruct from DB",
|
||||
log_prefix,
|
||||
cli_restore.mode,
|
||||
)
|
||||
cli_restore = None
|
||||
|
||||
# Validate, strip, and write to disk — delegate to helper to reduce
|
||||
# function complexity. Writing an invalid/corrupt file to disk then
|
||||
# falling back to "no --resume" would cause the CLI to fail with
|
||||
@@ -3529,11 +3542,12 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
if _cli_content:
|
||||
await asyncio.shield(
|
||||
upload_cli_session(
|
||||
upload_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
content=_cli_content,
|
||||
message_count=len(session.messages),
|
||||
mode="sdk",
|
||||
log_prefix=log_prefix,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -12,16 +12,18 @@ from backend.copilot.transcript import (
|
||||
ENTRY_TYPE_MESSAGE,
|
||||
STOP_REASON_END_TURN,
|
||||
STRIPPABLE_TYPES,
|
||||
CliSessionRestore,
|
||||
TranscriptDownload,
|
||||
TranscriptMode,
|
||||
cleanup_stale_project_dirs,
|
||||
compact_transcript,
|
||||
delete_transcript,
|
||||
detect_gap,
|
||||
download_transcript,
|
||||
read_compacted_entries,
|
||||
restore_cli_session,
|
||||
strip_for_upload,
|
||||
strip_progress_entries,
|
||||
strip_stale_thinking_blocks,
|
||||
upload_cli_session,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
write_transcript_to_tempfile,
|
||||
)
|
||||
@@ -31,16 +33,18 @@ __all__ = [
|
||||
"ENTRY_TYPE_MESSAGE",
|
||||
"STOP_REASON_END_TURN",
|
||||
"STRIPPABLE_TYPES",
|
||||
"CliSessionRestore",
|
||||
"TranscriptDownload",
|
||||
"TranscriptMode",
|
||||
"cleanup_stale_project_dirs",
|
||||
"compact_transcript",
|
||||
"delete_transcript",
|
||||
"detect_gap",
|
||||
"download_transcript",
|
||||
"read_compacted_entries",
|
||||
"restore_cli_session",
|
||||
"strip_for_upload",
|
||||
"strip_progress_entries",
|
||||
"strip_stale_thinking_blocks",
|
||||
"upload_cli_session",
|
||||
"upload_transcript",
|
||||
"validate_transcript",
|
||||
"write_transcript_to_tempfile",
|
||||
]
|
||||
|
||||
@@ -7,7 +7,7 @@ import pytest
|
||||
from .model import create_chat_session, get_chat_session, upsert_chat_session
|
||||
from .response_model import StreamError, StreamTextDelta
|
||||
from .sdk import service as sdk_service
|
||||
from .transcript import restore_cli_session
|
||||
from .transcript import download_transcript
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -64,7 +64,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
|
||||
cli_session = None
|
||||
for _ in range(10):
|
||||
await asyncio.sleep(0.5)
|
||||
cli_session = await restore_cli_session(test_user_id, session.session_id)
|
||||
cli_session = await download_transcript(test_user_id, session.session_id)
|
||||
if cli_session:
|
||||
break
|
||||
if not cli_session:
|
||||
|
||||
@@ -20,6 +20,7 @@ import shutil
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
from backend.util import json
|
||||
@@ -27,6 +28,9 @@ from backend.util.clients import get_openai_client
|
||||
from backend.util.prompt import CompressResult, compress_context
|
||||
from backend.util.workspace_storage import GCSWorkspaceStorage, get_workspace_storage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.copilot.model import ChatMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
|
||||
@@ -44,12 +48,14 @@ STRIPPABLE_TYPES = frozenset(
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CliSessionRestore:
|
||||
"""Result of restoring the CLI native session file."""
|
||||
TranscriptMode = Literal["sdk", "baseline"]
|
||||
|
||||
content: bytes # raw bytes written to disk (for builder seeding)
|
||||
message_count: int = 0 # watermark from companion .meta.json
|
||||
|
||||
@dataclass
class TranscriptDownload:
    """Result of downloading a stored transcript together with its metadata.

    ``message_count`` and ``mode`` come from the companion ``.meta.json``;
    both keep their defaults when that file is missing or unreadable.
    """

    # Raw transcript bytes as downloaded from storage.
    content: bytes
    # Message-count watermark from the companion .meta.json; 0 when unknown.
    message_count: int = 0
    # Which path wrote the transcript; "sdk" is assumed when not recorded.
    mode: TranscriptMode = "sdk"  # "sdk" = Claude CLI native, "baseline" = TranscriptBuilder
|
||||
|
||||
|
||||
# Storage prefix for the CLI's native session JSONL files (for cross-pod --resume).
|
||||
@@ -661,11 +667,12 @@ def _cli_session_meta_path_parts(user_id: str, session_id: str) -> tuple[str, st
|
||||
)
|
||||
|
||||
|
||||
async def upload_cli_session(
|
||||
async def upload_transcript(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
content: bytes,
|
||||
message_count: int = 0,
|
||||
mode: TranscriptMode = "sdk",
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> None:
|
||||
"""Upload CLI session content to GCS with companion meta.json.
|
||||
@@ -674,7 +681,7 @@ async def upload_cli_session(
|
||||
the session file from disk before calling this function.
|
||||
|
||||
Also uploads a companion .meta.json with the message_count watermark so
|
||||
restore_cli_session can return it without a separate fetch.
|
||||
download_transcript can return it without a separate fetch.
|
||||
|
||||
Called after each turn so the next turn can restore the file on any pod
|
||||
(eliminating the pod-affinity requirement for --resume).
|
||||
@@ -682,7 +689,7 @@ async def upload_cli_session(
|
||||
storage = await get_workspace_storage()
|
||||
wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id)
|
||||
mwid, mfid, mfname = _cli_session_meta_path_parts(user_id, session_id)
|
||||
meta = {"message_count": message_count, "uploaded_at": time.time()}
|
||||
meta = {"message_count": message_count, "mode": mode, "uploaded_at": time.time()}
|
||||
meta_encoded = json.dumps(meta).encode("utf-8")
|
||||
|
||||
session_result, meta_result = await asyncio.gather(
|
||||
@@ -709,18 +716,18 @@ async def upload_cli_session(
|
||||
)
|
||||
|
||||
|
||||
async def restore_cli_session(
|
||||
async def download_transcript(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
log_prefix: str = "[Transcript]",
|
||||
) -> CliSessionRestore | None:
|
||||
"""Download CLI session from GCS. Returns content + message_count, or None if not found.
|
||||
) -> TranscriptDownload | None:
|
||||
"""Download CLI session from GCS. Returns content + message_count + mode, or None if not found.
|
||||
|
||||
Pure GCS operation — no disk I/O. The caller is responsible for writing
|
||||
content to disk if --resume is needed.
|
||||
|
||||
Returns a CliSessionRestore with the raw content and message_count watermark
|
||||
on success, or None if not available (first turn or upload failed).
|
||||
Returns a TranscriptDownload with the raw content, message_count watermark,
|
||||
and mode on success, or None if not available (first turn or upload failed).
|
||||
"""
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_path_from_parts(
|
||||
@@ -747,15 +754,18 @@ async def restore_cli_session(
|
||||
|
||||
content: bytes = content_result
|
||||
|
||||
# Parse message_count from companion meta — best-effort, default to 0.
|
||||
# Parse message_count and mode from companion meta — best-effort, defaults.
|
||||
message_count = 0
|
||||
mode: TranscriptMode = "sdk"
|
||||
if isinstance(meta_result, FileNotFoundError):
|
||||
pass # No meta — first upload or old version; default to 0
|
||||
pass # No meta — old upload; default to "sdk"
|
||||
elif isinstance(meta_result, BaseException):
|
||||
logger.debug("%s Failed to load CLI session meta: %s", log_prefix, meta_result)
|
||||
else:
|
||||
meta = json.loads(meta_result.decode("utf-8"), fallback={})
|
||||
message_count = meta.get("message_count", 0)
|
||||
raw_mode = meta.get("mode", "sdk")
|
||||
mode = raw_mode if raw_mode in ("sdk", "baseline") else "sdk"
|
||||
|
||||
logger.info(
|
||||
"%s Downloaded CLI session (%dB, msg_count=%d)",
|
||||
@@ -763,7 +773,29 @@ async def restore_cli_session(
|
||||
len(content),
|
||||
message_count,
|
||||
)
|
||||
return CliSessionRestore(content=content, message_count=message_count)
|
||||
return TranscriptDownload(content=content, message_count=message_count, mode=mode)
|
||||
|
||||
|
||||
def detect_gap(
    download: TranscriptDownload,
    session_messages: list[ChatMessage],
) -> list[ChatMessage]:
    """Return chat-db messages after the transcript watermark (excluding current user turn).

    ``download.message_count`` is the watermark: how many leading entries of
    ``session_messages`` the downloaded transcript already covers. The last
    entry of ``session_messages`` is treated as the current user turn and is
    never part of the returned gap.

    Args:
        download: The downloaded transcript; only ``message_count`` is read.
        session_messages: Full chat-db message list for the session, ending
            with the current user turn.

    Returns:
        A new list with the messages at positions ``[watermark, len - 1)``.
        Returns [] if the transcript is current, the watermark is missing,
        zero, negative or not an int, or the watermark position doesn't end
        on an assistant turn (misaligned watermark).
    """
    wm = download.message_count
    # The watermark originates from a best-effort metadata read, so reject
    # anything that is not a positive int. A negative value would otherwise
    # be interpreted as indexing from the end of the list and yield a bogus
    # gap (or raise IndexError).
    if not isinstance(wm, int) or wm <= 0:
        return []
    total = len(session_messages)
    # Nothing between the watermark and the current (last) user turn.
    if wm >= total - 1:
        return []
    # Sanity: position wm-1 should be an assistant turn; misaligned watermark
    # means the DB messages shifted (e.g. deletion) — skip gap to avoid wrong context.
    if session_messages[wm - 1].role != "assistant":
        return []
    return list(session_messages[wm : total - 1])
|
||||
|
||||
|
||||
async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
|
||||
@@ -10,7 +10,7 @@ from unittest.mock import MagicMock
|
||||
from backend.util import json
|
||||
|
||||
from .transcript import (
|
||||
CliSessionRestore,
|
||||
TranscriptDownload,
|
||||
_build_path_from_parts,
|
||||
_find_last_assistant_entry,
|
||||
_flatten_assistant_content,
|
||||
@@ -720,7 +720,7 @@ class TestUploadCliSession:
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import upload_cli_session
|
||||
from .transcript import upload_transcript
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
content = b'{"type":"assistant"}\n'
|
||||
@@ -731,7 +731,7 @@ class TestUploadCliSession:
|
||||
return_value=mock_storage,
|
||||
):
|
||||
asyncio.run(
|
||||
upload_cli_session(
|
||||
upload_transcript(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000001",
|
||||
content=content,
|
||||
@@ -742,12 +742,12 @@ class TestUploadCliSession:
|
||||
assert mock_storage.store.call_count == 2
|
||||
|
||||
def test_uploads_companion_meta_json_with_message_count(self):
|
||||
"""upload_cli_session stores a companion .meta.json with message_count."""
|
||||
"""upload_transcript stores a companion .meta.json with message_count."""
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import upload_cli_session
|
||||
from .transcript import upload_transcript
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
content = b'{"type":"assistant"}\n'
|
||||
@@ -758,7 +758,7 @@ class TestUploadCliSession:
|
||||
return_value=mock_storage,
|
||||
):
|
||||
asyncio.run(
|
||||
upload_cli_session(
|
||||
upload_transcript(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000010",
|
||||
content=content,
|
||||
@@ -781,7 +781,7 @@ class TestUploadCliSession:
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import upload_cli_session
|
||||
from .transcript import upload_transcript
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.store.side_effect = [RuntimeError("gcs unavailable"), None]
|
||||
@@ -794,7 +794,7 @@ class TestUploadCliSession:
|
||||
):
|
||||
# Should not raise — failures are logged as warnings
|
||||
asyncio.run(
|
||||
upload_cli_session(
|
||||
upload_transcript(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000002",
|
||||
content=content,
|
||||
@@ -810,7 +810,7 @@ class TestRestoreCliSession:
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import restore_cli_session
|
||||
from .transcript import download_transcript
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.retrieve.side_effect = [
|
||||
@@ -824,7 +824,7 @@ class TestRestoreCliSession:
|
||||
return_value=mock_storage,
|
||||
):
|
||||
result = asyncio.run(
|
||||
restore_cli_session(
|
||||
download_transcript(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000000",
|
||||
)
|
||||
@@ -832,12 +832,12 @@ class TestRestoreCliSession:
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_cli_session_restore_on_success_no_meta(self):
|
||||
"""Happy path with no meta.json: returns CliSessionRestore with message_count=0."""
|
||||
def test_returns_transcript_download_on_success_no_meta(self):
|
||||
"""Happy path with no meta.json: returns TranscriptDownload with message_count=0."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import restore_cli_session
|
||||
from .transcript import download_transcript
|
||||
|
||||
session_id = "12345678-0000-0000-0000-000000000003"
|
||||
content = b'{"type":"assistant"}\n'
|
||||
@@ -851,27 +851,28 @@ class TestRestoreCliSession:
|
||||
return_value=mock_storage,
|
||||
):
|
||||
result = asyncio.run(
|
||||
restore_cli_session(
|
||||
download_transcript(
|
||||
user_id="user-1",
|
||||
session_id=session_id,
|
||||
)
|
||||
)
|
||||
|
||||
assert isinstance(result, CliSessionRestore)
|
||||
assert isinstance(result, TranscriptDownload)
|
||||
assert result.content == content
|
||||
assert result.message_count == 0
|
||||
assert result.mode == "sdk"
|
||||
|
||||
def test_returns_cli_session_restore_with_message_count_from_meta(self):
|
||||
"""When meta.json is present, message_count is read from it."""
|
||||
def test_returns_transcript_download_with_message_count_from_meta(self):
|
||||
"""When meta.json is present, message_count and mode are read from it."""
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import restore_cli_session
|
||||
from .transcript import download_transcript
|
||||
|
||||
session_id = "12345678-0000-0000-0000-000000000005"
|
||||
content = b'{"type":"assistant"}\n'
|
||||
meta_bytes = json.dumps({"message_count": 7, "uploaded_at": 1234567.0}).encode()
|
||||
meta_bytes = json.dumps({"message_count": 7, "mode": "sdk", "uploaded_at": 1234567.0}).encode()
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.retrieve.side_effect = [content, meta_bytes]
|
||||
@@ -882,22 +883,23 @@ class TestRestoreCliSession:
|
||||
return_value=mock_storage,
|
||||
):
|
||||
result = asyncio.run(
|
||||
restore_cli_session(
|
||||
download_transcript(
|
||||
user_id="user-1",
|
||||
session_id=session_id,
|
||||
)
|
||||
)
|
||||
|
||||
assert isinstance(result, CliSessionRestore)
|
||||
assert isinstance(result, TranscriptDownload)
|
||||
assert result.content == content
|
||||
assert result.message_count == 7
|
||||
assert result.mode == "sdk"
|
||||
|
||||
def test_returns_none_on_download_exception(self):
|
||||
"""Non-FileNotFoundError during retrieve logs warning and returns None."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import restore_cli_session
|
||||
from .transcript import download_transcript
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.retrieve.side_effect = [
|
||||
@@ -911,7 +913,7 @@ class TestRestoreCliSession:
|
||||
return_value=mock_storage,
|
||||
):
|
||||
result = asyncio.run(
|
||||
restore_cli_session(
|
||||
download_transcript(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000004",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user