@dataclass
class _RestoreResult:
    """Return value from ``_restore_cli_session_for_turn``.

    The defaults describe the "nothing restored" outcome: no transcript,
    no ``--resume``, and a message-count watermark of zero.
    """

    # Transcript text that was loaded into the TranscriptBuilder ("" if none).
    transcript_content: str = ""
    # Set to False when the downloaded session failed validation or when
    # DB reconstruction produced no output.
    transcript_covers_prefix: bool = True
    # True only when a session file was put on disk for the CLI to resume.
    use_resume: bool = False
    # Session id to resume from; set together with ``use_resume``.
    resume_file: str | None = None
    # Number of session messages represented by ``transcript_content``.
    transcript_msg_count: int = 0


async def _restore_cli_session_for_turn(
    user_id: str | None,
    session_id: str,
    session: "ChatSession",
    sdk_cwd: str,
    transcript_builder: "TranscriptBuilder",
    log_prefix: str,
) -> _RestoreResult:
    """Download, validate and restore a CLI session for ``--resume`` on this turn.

    Performs a single GCS round-trip to fetch the session bytes + message_count
    watermark. Falls back to DB-message reconstruction when GCS has no usable
    session (first turn, upload missed, non-SDK transcript, failed validation,
    or no ``sdk_cwd`` to restore into).

    Args:
        user_id: Owner of the session; restore is skipped when falsy.
        session_id: Chat session id, also used as the CLI resume target.
        session: Chat session whose ``messages`` drive the fallback
            reconstruction (all but the last message are used).
        sdk_cwd: Working directory the CLI session file lives under. When
            empty, nothing is written to disk and ``--resume`` is never
            requested.
        transcript_builder: Builder that receives the restored transcript
            via ``load_previous`` (mutated in place).
        log_prefix: Prefix prepended to every log line.

    Returns:
        A ``_RestoreResult`` with all transcript-related state ready for the
        caller to merge into its local variables.
    """
    result = _RestoreResult()

    # Nothing to restore when the feature is off, the user is unknown, or the
    # session holds at most one message.
    if not (config.claude_agent_use_resume and user_id and len(session.messages) > 1):
        return result

    try:
        cli_restore = await download_transcript(
            user_id, session_id, log_prefix=log_prefix
        )
    except Exception as restore_err:
        # Best-effort: a failed download must not abort the turn.
        logger.warning(
            "%s CLI session restore failed, continuing without --resume: %s",
            log_prefix,
            restore_err,
        )
        cli_restore = None

    # Only attempt --resume for SDK-written transcripts.
    # Baseline-written transcripts use TranscriptBuilder format (synthetic IDs,
    # stripped fields) that may not be valid for --resume.
    if cli_restore is not None and cli_restore.mode != "sdk":
        logger.info(
            "%s Transcript written by mode=%r, skipping --resume — will reconstruct from DB",
            log_prefix,
            cli_restore.mode,
        )
        cli_restore = None

    # Without a working directory the session can be neither validated nor
    # written to disk. Reporting use_resume=True here would point the CLI at
    # a file that was never created (with an empty transcript), so treat the
    # download as unusable and let the DB fallback handle this turn.
    if cli_restore is not None and not sdk_cwd:
        logger.info(
            "%s No sdk_cwd to restore the CLI session into, "
            "skipping --resume — will reconstruct from DB",
            log_prefix,
        )
        cli_restore = None

    # Validate, strip, and write to disk — delegated to a helper. Writing an
    # invalid/corrupt file to disk then falling back to "no --resume" would
    # cause the CLI to fail with "Session ID already in use" because the file
    # exists at the expected session path, so validation happens BEFORE any
    # disk write.
    stripped = ""
    if cli_restore is not None:
        stripped, ok = _process_cli_restore(
            cli_restore, sdk_cwd, session_id, log_prefix
        )
        if not ok:
            result.transcript_covers_prefix = False
            cli_restore = None

    if cli_restore is not None:
        result.transcript_content = stripped
        transcript_builder.load_previous(stripped, log_prefix=log_prefix)
        result.use_resume = True
        result.resume_file = session_id
        result.transcript_msg_count = cli_restore.message_count
        return result

    if sdk_cwd:
        # Validation failed or GCS returned no session. Delete any existing
        # local session file so the CLI doesn't reject the session_id with
        # "Session ID already in use". A previous turn may have left a valid
        # file at this path; clearing it lets the fallback path create a new
        # session.
        _stale_path = os.path.realpath(cli_session_path(sdk_cwd, session_id))
        # Only unlink paths that resolve inside the projects base directory.
        if Path(_stale_path).exists() and _stale_path.startswith(
            projects_base() + os.sep
        ):
            try:
                Path(_stale_path).unlink()
                logger.debug(
                    "%s Removed stale local CLI session file for clean fallback",
                    log_prefix,
                )
            except OSError as _unlink_err:
                logger.debug(
                    "%s Failed to remove stale local session file: %s",
                    log_prefix,
                    _unlink_err,
                )

    # No usable CLI session — reconstruct from DB messages as last-resort
    # fallback. The final message is excluded from the reconstruction.
    prior = session.messages[:-1]
    reconstructed = _session_messages_to_transcript(prior)
    if not reconstructed:
        logger.warning(
            "%s No session available and reconstruction produced empty output "
            "(%d messages in session)",
            log_prefix,
            len(session.messages),
        )
        result.transcript_covers_prefix = False
        return result

    result.transcript_content = reconstructed
    transcript_builder.load_previous(reconstructed, log_prefix=log_prefix)
    result.transcript_msg_count = len(prior)
    result.transcript_covers_prefix = True

    if not sdk_cwd:
        logger.info(
            "%s Reconstructed transcript from %d session messages "
            "(no sdk_cwd — running without --resume this turn)",
            log_prefix,
            len(prior),
        )
        return result

    # Write the reconstructed transcript to disk so the CLI can --resume on
    # this turn (avoids context-free response and seeds the native session
    # for cross-pod restore next turn).
    if _write_cli_session_to_disk(
        reconstructed.encode("utf-8"), sdk_cwd, session_id, log_prefix
    ):
        result.use_resume = True
        result.resume_file = session_id
        logger.info(
            "%s Reconstructed transcript from %d session messages "
            "and wrote to disk for --resume",
            log_prefix,
            len(prior),
        )
    else:
        logger.info(
            "%s Reconstructed transcript from %d session messages "
            "(disk write failed — running without --resume this turn)",
            log_prefix,
            len(prior),
        )

    return result
- transcript_msg_count = 0 - if config.claude_agent_use_resume and user_id and len(session.messages) > 1: - try: - cli_restore = await download_transcript( - user_id, session_id, log_prefix=log_prefix - ) - except Exception as restore_err: - logger.warning( - "%s CLI session restore failed, continuing without --resume: %s", - log_prefix, - restore_err, - ) - cli_restore = None - - # Only attempt --resume for SDK-written transcripts. - # Baseline-written transcripts use TranscriptBuilder format (synthetic IDs, - # stripped fields) that may not be valid for --resume. - if cli_restore is not None and cli_restore.mode != "sdk": - logger.info( - "%s Transcript written by mode=%r, skipping --resume — will reconstruct from DB", - log_prefix, - cli_restore.mode, - ) - cli_restore = None - - # Validate, strip, and write to disk — delegate to helper to reduce - # function complexity. Writing an invalid/corrupt file to disk then - # falling back to "no --resume" would cause the CLI to fail with - # "Session ID already in use" because the file exists at the expected - # session path, so we validate BEFORE any disk write. - stripped = "" - if cli_restore is not None and sdk_cwd: - stripped, ok = _process_cli_restore( - cli_restore, sdk_cwd, session_id, log_prefix - ) - if not ok: - transcript_covers_prefix = False - cli_restore = None - - if cli_restore is None and sdk_cwd: - # Validation failed or GCS returned no session. Delete any - # existing local session file so the CLI doesn't reject the - # session_id with "Session ID already in use". T1 may have - # left a valid file at this path; we clear it so the fallback - # path (session_id= without --resume) can create a new session. 
- _stale_path = os.path.realpath(cli_session_path(sdk_cwd, session_id)) - if Path(_stale_path).exists() and _stale_path.startswith( - projects_base() + os.sep - ): - try: - Path(_stale_path).unlink() - logger.debug( - "%s Removed stale local CLI session file for clean fallback", - log_prefix, - ) - except OSError as _unlink_err: - logger.debug( - "%s Failed to remove stale local session file: %s", - log_prefix, - _unlink_err, - ) - - if cli_restore is not None: - transcript_content = stripped - transcript_builder.load_previous(stripped, log_prefix=log_prefix) - use_resume = True - resume_file = session_id - transcript_msg_count = cli_restore.message_count - else: - # No CLI session in GCS — reconstruct from DB messages as last-resort fallback. - prior = session.messages[:-1] - reconstructed = _session_messages_to_transcript(prior) - if reconstructed and sdk_cwd: - transcript_content = reconstructed - transcript_builder.load_previous( - reconstructed, log_prefix=log_prefix - ) - transcript_msg_count = len(prior) - transcript_covers_prefix = True - # Write the reconstructed transcript to disk so the CLI can - # --resume on this turn (avoids context-free response and - # seeds the native session for cross-pod restore next turn). 
- _reconstructed_bytes = reconstructed.encode("utf-8") - if _write_cli_session_to_disk( - _reconstructed_bytes, sdk_cwd, session_id, log_prefix - ): - use_resume = True - resume_file = session_id - logger.info( - "%s Reconstructed transcript from %d session messages " - "and wrote to disk for --resume", - log_prefix, - len(prior), - ) - else: - logger.info( - "%s Reconstructed transcript from %d session messages " - "(disk write failed — running without --resume this turn)", - log_prefix, - len(prior), - ) - elif reconstructed: - transcript_content = reconstructed - transcript_builder.load_previous( - reconstructed, log_prefix=log_prefix - ) - transcript_msg_count = len(prior) - transcript_covers_prefix = True - logger.info( - "%s Reconstructed transcript from %d session messages " - "(no sdk_cwd — running without --resume this turn)", - log_prefix, - len(prior), - ) - else: - logger.warning( - "%s No session available and reconstruction produced empty output " - "(%d messages in session)", - log_prefix, - len(session.messages), - ) - transcript_covers_prefix = False + _restore = await _restore_cli_session_for_turn( + user_id, session_id, session, sdk_cwd, transcript_builder, log_prefix + ) + transcript_content = _restore.transcript_content + transcript_covers_prefix = _restore.transcript_covers_prefix + use_resume = _restore.use_resume + resume_file = _restore.resume_file + transcript_msg_count = _restore.transcript_msg_count yield StreamStart(messageId=message_id, sessionId=session_id)