diff --git a/autogpt_platform/backend/backend/copilot/baseline/service.py b/autogpt_platform/backend/backend/copilot/baseline/service.py index 3ebc259d1c..bc23115adf 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service.py @@ -704,16 +704,15 @@ async def _compress_session_messages( def should_upload_transcript( - user_id: str | None, transcript_covers_prefix: bool + user_id: str | None, upload_safe: bool ) -> bool: """Return ``True`` when the caller should upload the final transcript. - Uploads require a logged-in user (for the storage key) *and* a - transcript that covered the session prefix when loaded — otherwise - we'd be overwriting a more complete version in storage with a - partial one built from just the current turn. + Uploads require a logged-in user (for the storage key) *and* a safe + upload signal from ``_load_prior_transcript`` — i.e. GCS does not hold a + newer version that we'd be overwriting. """ - return bool(user_id) and transcript_covers_prefix + return bool(user_id) and upload_safe def _append_gap_to_builder( @@ -778,11 +777,11 @@ async def _load_prior_transcript( ) -> tuple[bool, "TranscriptDownload | None"]: """Download and load the prior CLI session into ``transcript_builder``. - Returns a tuple of (covers_prefix, transcript_download): - - ``covers_prefix`` is ``True`` when the loaded session fully covers the - session prefix; ``False`` otherwise (missing, invalid, or download error). - Callers should suppress uploads when this is ``False`` to avoid overwriting - a more complete version in storage. + Returns a tuple of (upload_safe, transcript_download): + - ``upload_safe`` is ``True`` when it is safe to upload at the end of this + turn. Upload is suppressed only for **download errors** (unknown GCS + state) — missing and invalid files return ``True`` because there is + nothing in GCS worth protecting against overwriting. - ``transcript_download`` is a ``TranscriptDownload`` with str content (pre-decoded and stripped) when available, or ``None`` when no valid transcript could be loaded. Callers pass this to @@ -794,11 +793,14 @@ async def _load_prior_transcript( ) except Exception as e: logger.warning("[Baseline] Session restore failed: %s", e) + # Unknown GCS state — be conservative, skip upload. return False, None if restore is None: - logger.debug("[Baseline] No CLI session available") - return False, None + logger.debug("[Baseline] No CLI session available — will upload fresh") + # Nothing in GCS to protect; allow upload so the first baseline turn + # writes the initial transcript snapshot. + return True, None content_bytes = restore.content try: @@ -809,12 +811,14 @@ async def _load_prior_transcript( ) except UnicodeDecodeError: logger.warning("[Baseline] CLI session content is not valid UTF-8") - return False, None + # Corrupt file in GCS; overwriting with a valid one is better. + return True, None stripped = strip_for_upload(raw_str) if not validate_transcript(stripped): logger.warning("[Baseline] CLI session content invalid after strip") - return False, None + # Corrupt file in GCS; overwriting with a valid one is better. + return True, None transcript_builder.load_previous(stripped, log_prefix="[Baseline]") logger.info( @@ -965,7 +969,7 @@ async def stream_chat_completion_baseline( # --- Transcript support (feature parity with SDK path) --- transcript_builder = TranscriptBuilder() - transcript_covers_prefix = True + transcript_upload_safe = True # Build system prompt only on the first turn to avoid mid-conversation # changes from concurrent chats updating business understanding. @@ -985,7 +989,7 @@ async def stream_chat_completion_baseline( transcript_download: TranscriptDownload | None = None if user_id and len(session.messages) > 1: ( - (transcript_covers_prefix, transcript_download), + (transcript_upload_safe, transcript_download), (base_system_prompt, understanding), ) = await asyncio.gather( _load_prior_transcript( @@ -1382,7 +1386,7 @@ async def stream_chat_completion_baseline( stop_reason=STOP_REASON_END_TURN, ) - if user_id and should_upload_transcript(user_id, transcript_covers_prefix): + if user_id and should_upload_transcript(user_id, transcript_upload_safe): await _upload_final_transcript( user_id=user_id, session_id=session_id, diff --git a/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py b/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py index 9c4aa9cc12..bda348668f 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py +++ b/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py @@ -137,25 +137,27 @@ class TestLoadPriorTranscript: assert builder.entry_count == 4 @pytest.mark.asyncio - async def test_missing_transcript_returns_false(self): + async def test_missing_transcript_allows_upload(self): + """Nothing in GCS → upload is safe; the turn writes the first snapshot.""" builder = TranscriptBuilder() with patch( "backend.copilot.baseline.service.download_transcript", new=AsyncMock(return_value=None), ): - covers, dl = await _load_prior_transcript( + upload_safe, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", session_messages=_make_session_messages("user", "assistant"), transcript_builder=builder, ) - assert covers is False + assert upload_safe is True assert dl is None assert builder.is_empty @pytest.mark.asyncio - async def test_invalid_transcript_returns_false(self): + async def test_invalid_transcript_allows_upload(self): + """Corrupt file in GCS → overwriting with a valid one is better.""" builder = TranscriptBuilder() restore = TranscriptDownload( content=b'{"type":"progress","uuid":"a"}\n', @@ -166,14 +168,14 @@ class TestLoadPriorTranscript: "backend.copilot.baseline.service.download_transcript", new=AsyncMock(return_value=restore), ): - covers, dl = await _load_prior_transcript( + upload_safe, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", session_messages=_make_session_messages("user", "assistant"), transcript_builder=builder, ) - assert covers is False + assert upload_safe is True assert dl is None assert builder.is_empty @@ -558,7 +560,7 @@ class TestTranscriptLifecycle: # --- 3. Gate + upload --- assert ( should_upload_transcript( - user_id="user-1", transcript_covers_prefix=covers + user_id="user-1", upload_safe=covers ) is True ) @@ -625,14 +627,13 @@ class TestTranscriptLifecycle: ) assert ( - should_upload_transcript(user_id=None, transcript_covers_prefix=True) + should_upload_transcript(user_id=None, upload_safe=True) is False ) @pytest.mark.asyncio async def test_lifecycle_missing_download_still_uploads_new_content(self): - """No prior session → covers defaults to True in the service, - new turn should upload cleanly.""" + """No prior session → upload is safe; the turn writes the first snapshot.""" builder = TranscriptBuilder() upload_mock = AsyncMock(return_value=None) with ( @@ -645,24 +646,22 @@ class TestTranscriptLifecycle: new=upload_mock, ), ): - covers, dl = await _load_prior_transcript( + upload_safe, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", session_messages=_make_session_messages("user"), transcript_builder=builder, ) - # No restore: covers is False, so the production path would - # skip upload. This protects against overwriting a future - # more-complete session with a single-turn snapshot. - assert covers is False + # Nothing in GCS → upload is safe so the first baseline turn + # can write the initial transcript snapshot. + assert upload_safe is True assert dl is None assert ( should_upload_transcript( - user_id="user-1", transcript_covers_prefix=covers + user_id="user-1", upload_safe=upload_safe ) - is False + is True ) - upload_mock.assert_not_awaited() # ---------------------------------------------------------------------------