mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
fix(backend/copilot): baseline always uploads when GCS has no transcript
_load_prior_transcript was returning (False, None) for missing/invalid transcripts, preventing the upload guard from firing. The intent was to protect against overwriting a *newer* GCS version — but a missing or corrupt file has nothing worth protecting. Only download errors (unknown GCS state) should suppress upload now. Root cause of the session 7803bde1 bug: the baseline turn ran with 23 session messages, no transcript existed in GCS (first baseline turn in that session), _load_prior_transcript returned (False, None), and should_upload_transcript gated the upload to False. The SDK's subsequent turn found no baseline JSONL and fell back to full DB reconstruction. Also renames transcript_covers_prefix → transcript_upload_safe throughout to accurately reflect the flag's semantics.
This commit is contained in:
@@ -704,16 +704,15 @@ async def _compress_session_messages(
|
||||
|
||||
|
||||
def should_upload_transcript(
|
||||
user_id: str | None, transcript_covers_prefix: bool
|
||||
user_id: str | None, upload_safe: bool
|
||||
) -> bool:
|
||||
"""Return ``True`` when the caller should upload the final transcript.
|
||||
|
||||
Uploads require a logged-in user (for the storage key) *and* a
|
||||
transcript that covered the session prefix when loaded — otherwise
|
||||
we'd be overwriting a more complete version in storage with a
|
||||
partial one built from just the current turn.
|
||||
Uploads require a logged-in user (for the storage key) *and* a safe
|
||||
upload signal from ``_load_prior_transcript`` — i.e. GCS does not hold a
|
||||
newer version that we'd be overwriting.
|
||||
"""
|
||||
return bool(user_id) and transcript_covers_prefix
|
||||
return bool(user_id) and upload_safe
|
||||
|
||||
|
||||
def _append_gap_to_builder(
|
||||
@@ -778,11 +777,11 @@ async def _load_prior_transcript(
|
||||
) -> tuple[bool, "TranscriptDownload | None"]:
|
||||
"""Download and load the prior CLI session into ``transcript_builder``.
|
||||
|
||||
Returns a tuple of (covers_prefix, transcript_download):
|
||||
- ``covers_prefix`` is ``True`` when the loaded session fully covers the
|
||||
session prefix; ``False`` otherwise (missing, invalid, or download error).
|
||||
Callers should suppress uploads when this is ``False`` to avoid overwriting
|
||||
a more complete version in storage.
|
||||
Returns a tuple of (upload_safe, transcript_download):
|
||||
- ``upload_safe`` is ``True`` when it is safe to upload at the end of this
|
||||
turn. Upload is suppressed only for **download errors** (unknown GCS
|
||||
state) — missing and invalid files return ``True`` because there is
|
||||
nothing in GCS worth protecting against overwriting.
|
||||
- ``transcript_download`` is a ``TranscriptDownload`` with str content
|
||||
(pre-decoded and stripped) when available, or ``None`` when no valid
|
||||
transcript could be loaded. Callers pass this to
|
||||
@@ -794,11 +793,14 @@ async def _load_prior_transcript(
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Session restore failed: %s", e)
|
||||
# Unknown GCS state — be conservative, skip upload.
|
||||
return False, None
|
||||
|
||||
if restore is None:
|
||||
logger.debug("[Baseline] No CLI session available")
|
||||
return False, None
|
||||
logger.debug("[Baseline] No CLI session available — will upload fresh")
|
||||
# Nothing in GCS to protect; allow upload so the first baseline turn
|
||||
# writes the initial transcript snapshot.
|
||||
return True, None
|
||||
|
||||
content_bytes = restore.content
|
||||
try:
|
||||
@@ -809,12 +811,14 @@ async def _load_prior_transcript(
|
||||
)
|
||||
except UnicodeDecodeError:
|
||||
logger.warning("[Baseline] CLI session content is not valid UTF-8")
|
||||
return False, None
|
||||
# Corrupt file in GCS; overwriting with a valid one is better.
|
||||
return True, None
|
||||
|
||||
stripped = strip_for_upload(raw_str)
|
||||
if not validate_transcript(stripped):
|
||||
logger.warning("[Baseline] CLI session content invalid after strip")
|
||||
return False, None
|
||||
# Corrupt file in GCS; overwriting with a valid one is better.
|
||||
return True, None
|
||||
|
||||
transcript_builder.load_previous(stripped, log_prefix="[Baseline]")
|
||||
logger.info(
|
||||
@@ -965,7 +969,7 @@ async def stream_chat_completion_baseline(
|
||||
|
||||
# --- Transcript support (feature parity with SDK path) ---
|
||||
transcript_builder = TranscriptBuilder()
|
||||
transcript_covers_prefix = True
|
||||
transcript_upload_safe = True
|
||||
|
||||
# Build system prompt only on the first turn to avoid mid-conversation
|
||||
# changes from concurrent chats updating business understanding.
|
||||
@@ -985,7 +989,7 @@ async def stream_chat_completion_baseline(
|
||||
transcript_download: TranscriptDownload | None = None
|
||||
if user_id and len(session.messages) > 1:
|
||||
(
|
||||
(transcript_covers_prefix, transcript_download),
|
||||
(transcript_upload_safe, transcript_download),
|
||||
(base_system_prompt, understanding),
|
||||
) = await asyncio.gather(
|
||||
_load_prior_transcript(
|
||||
@@ -1382,7 +1386,7 @@ async def stream_chat_completion_baseline(
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
if user_id and should_upload_transcript(user_id, transcript_covers_prefix):
|
||||
if user_id and should_upload_transcript(user_id, transcript_upload_safe):
|
||||
await _upload_final_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
|
||||
@@ -137,25 +137,27 @@ class TestLoadPriorTranscript:
|
||||
assert builder.entry_count == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_transcript_returns_false(self):
|
||||
async def test_missing_transcript_allows_upload(self):
|
||||
"""Nothing in GCS → upload is safe; the turn writes the first snapshot."""
|
||||
builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=None),
|
||||
):
|
||||
covers, dl = await _load_prior_transcript(
|
||||
upload_safe, dl = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_messages=_make_session_messages("user", "assistant"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
assert upload_safe is True
|
||||
assert dl is None
|
||||
assert builder.is_empty
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_transcript_returns_false(self):
|
||||
async def test_invalid_transcript_allows_upload(self):
|
||||
"""Corrupt file in GCS → overwriting with a valid one is better."""
|
||||
builder = TranscriptBuilder()
|
||||
restore = TranscriptDownload(
|
||||
content=b'{"type":"progress","uuid":"a"}\n',
|
||||
@@ -166,14 +168,14 @@ class TestLoadPriorTranscript:
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=restore),
|
||||
):
|
||||
covers, dl = await _load_prior_transcript(
|
||||
upload_safe, dl = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_messages=_make_session_messages("user", "assistant"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
assert upload_safe is True
|
||||
assert dl is None
|
||||
assert builder.is_empty
|
||||
|
||||
@@ -558,7 +560,7 @@ class TestTranscriptLifecycle:
|
||||
# --- 3. Gate + upload ---
|
||||
assert (
|
||||
should_upload_transcript(
|
||||
user_id="user-1", transcript_covers_prefix=covers
|
||||
user_id="user-1", upload_safe=covers
|
||||
)
|
||||
is True
|
||||
)
|
||||
@@ -625,14 +627,13 @@ class TestTranscriptLifecycle:
|
||||
)
|
||||
|
||||
assert (
|
||||
should_upload_transcript(user_id=None, transcript_covers_prefix=True)
|
||||
should_upload_transcript(user_id=None, upload_safe=True)
|
||||
is False
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_missing_download_still_uploads_new_content(self):
|
||||
"""No prior session → covers defaults to True in the service,
|
||||
new turn should upload cleanly."""
|
||||
"""No prior session → upload is safe; the turn writes the first snapshot."""
|
||||
builder = TranscriptBuilder()
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with (
|
||||
@@ -645,24 +646,22 @@ class TestTranscriptLifecycle:
|
||||
new=upload_mock,
|
||||
),
|
||||
):
|
||||
covers, dl = await _load_prior_transcript(
|
||||
upload_safe, dl = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_messages=_make_session_messages("user"),
|
||||
transcript_builder=builder,
|
||||
)
|
||||
# No restore: covers is False, so the production path would
|
||||
# skip upload. This protects against overwriting a future
|
||||
# more-complete session with a single-turn snapshot.
|
||||
assert covers is False
|
||||
# Nothing in GCS → upload is safe so the first baseline turn
|
||||
# can write the initial transcript snapshot.
|
||||
assert upload_safe is True
|
||||
assert dl is None
|
||||
assert (
|
||||
should_upload_transcript(
|
||||
user_id="user-1", transcript_covers_prefix=covers
|
||||
user_id="user-1", upload_safe=upload_safe
|
||||
)
|
||||
is False
|
||||
is True
|
||||
)
|
||||
upload_mock.assert_not_awaited()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user