mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
fix(backend/copilot): address review comments on PR #12777
- transcript.py: replace isinstance(storage, GCS) branching in restore_cli_session with _build_path_from_parts (extensible to new backends without fallthrough bugs) - transcript.py: log real_path (post-realpath) instead of session_file in the boundary-check warning for clarity when symlinks are involved - service.py: add comment explaining asyncio.shield + CancelledError semantics to match the upload_transcript pattern above it - service_helpers_test.py: rename test_compaction_same_size_drops to test_compaction_invalid_transcript_drops (the test validates invalid compacted content, not same-size content) - transcript_test.py: add unit tests for upload_cli_session, restore_cli_session, and _cli_session_path covering the path boundary guard, FileNotFoundError fast-path, and success/failure return values - retry_scenarios_test.py: add integration test test_resume_skipped_when_cli_session_missing to TestStreamChatCompletion RetryIntegration — exercises the actual cli_restored branch so changes to service.py are caught immediately
This commit is contained in:
@@ -813,8 +813,12 @@ class TestRetryStateReset:
|
||||
|
||||
def test_cli_session_restore_failure_skips_resume(self):
|
||||
"""When restore_cli_session returns False, --resume is not used.
|
||||
The transcript builder is still populated for future upload_transcript."""
|
||||
# Simulate the logic from the primary resume path in service.py.
|
||||
The transcript builder is still populated for future upload_transcript.
|
||||
|
||||
This covers the guard on the cli_restored branch in service.py.
|
||||
For a full integration test exercising the actual service code path,
|
||||
see TestStreamChatCompletionRetryIntegration.test_resume_skipped_when_cli_session_missing.
|
||||
"""
|
||||
use_resume = False
|
||||
resume_file = None
|
||||
cli_restored = False # restore_cli_session returned False
|
||||
@@ -823,7 +827,6 @@ class TestRetryStateReset:
|
||||
use_resume = True
|
||||
resume_file = "sess-uuid"
|
||||
|
||||
# resume is not activated; builder state may still be loaded
|
||||
assert use_resume is False
|
||||
assert resume_file is None
|
||||
|
||||
@@ -999,7 +1002,10 @@ def _make_sdk_patches(
|
||||
return_value=MagicMock(content=original_transcript, message_count=2),
|
||||
),
|
||||
),
|
||||
(f"{_SVC}.restore_cli_session", dict(new_callable=AsyncMock, return_value=True)),
|
||||
(
|
||||
f"{_SVC}.restore_cli_session",
|
||||
dict(new_callable=AsyncMock, return_value=True),
|
||||
),
|
||||
(f"{_SVC}.upload_cli_session", dict(new_callable=AsyncMock)),
|
||||
(f"{_SVC}.validate_transcript", dict(return_value=True)),
|
||||
(
|
||||
@@ -1878,3 +1884,65 @@ class TestStreamChatCompletionRetryIntegration:
|
||||
for e in status_events
|
||||
), f"Expected 'retrying' or 'interrupted' in StreamStatus, got: {[e.message for e in status_events]}"
|
||||
assert any(isinstance(e, StreamStart) for e in events)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_skipped_when_cli_session_missing(self):
|
||||
"""When restore_cli_session returns False, --resume is NOT passed to ClaudeSDKClient.
|
||||
|
||||
Exercises the actual service code path so any change to the cli_restored
|
||||
branch in service.py will be caught immediately by this test.
|
||||
"""
|
||||
import contextlib
|
||||
|
||||
from backend.copilot.response_model import StreamStart
|
||||
from backend.copilot.sdk.service import stream_chat_completion_sdk
|
||||
|
||||
session = self._make_session()
|
||||
result_msg = self._make_result_message()
|
||||
original_transcript = _build_transcript(
|
||||
[("user", "prior question"), ("assistant", "prior answer")]
|
||||
)
|
||||
captured_options: dict = {}
|
||||
|
||||
def _client_factory(**kwargs):
|
||||
captured_options.update(kwargs)
|
||||
return self._make_client_mock(result_message=result_msg)
|
||||
|
||||
patches = _make_sdk_patches(
|
||||
session,
|
||||
original_transcript=original_transcript,
|
||||
compacted_transcript=None,
|
||||
client_side_effect=_client_factory,
|
||||
)
|
||||
# Override restore_cli_session to return False (CLI native session unavailable)
|
||||
patches = [
|
||||
(
|
||||
(
|
||||
f"{_SVC}.restore_cli_session",
|
||||
dict(new_callable=AsyncMock, return_value=False),
|
||||
)
|
||||
if p[0] == f"{_SVC}.restore_cli_session"
|
||||
else p
|
||||
)
|
||||
for p in patches
|
||||
]
|
||||
|
||||
events = []
|
||||
with contextlib.ExitStack() as stack:
|
||||
for target, kwargs in patches:
|
||||
stack.enter_context(patch(target, **kwargs))
|
||||
async for event in stream_chat_completion_sdk(
|
||||
session_id="test-session-id",
|
||||
message="hello",
|
||||
is_user_message=True,
|
||||
user_id="test-user",
|
||||
session=session,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# --resume must NOT be passed when CLI native session was not restored
|
||||
assert "resume" not in captured_options, (
|
||||
f"--resume was set even though restore_cli_session returned False: "
|
||||
f"{captured_options}"
|
||||
)
|
||||
assert any(isinstance(e, StreamStart) for e in events)
|
||||
|
||||
@@ -2995,6 +2995,13 @@ async def stream_chat_completion_sdk(
|
||||
# The CLI writes its native session JSONL after each turn completes.
|
||||
# Uploading it here enables --resume on any pod (no pod affinity needed).
|
||||
# Runs after upload_transcript so both are available for the next turn.
|
||||
# asyncio.shield: same pattern as upload_transcript above — if the
|
||||
# outer finally-block coroutine is cancelled while awaiting shield,
|
||||
# the CancelledError propagates (BaseException, not caught by
|
||||
# `except Exception`) letting the caller handle cancellation, while
|
||||
# the shielded inner coroutine continues running to completion so the
|
||||
# upload is not lost. This is intentional and matches the pattern
|
||||
# used for upload_transcript immediately above.
|
||||
if (
|
||||
config.claude_agent_use_resume
|
||||
and user_id
|
||||
|
||||
@@ -185,7 +185,7 @@ class TestReduceContext:
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compaction_same_size_drops(self) -> None:
|
||||
async def test_compaction_invalid_transcript_drops(self) -> None:
|
||||
# When validate_transcript returns False for compacted content, drop transcript.
|
||||
transcript = _build_transcript([("user", "hi"), ("assistant", "hello")])
|
||||
compacted = _build_transcript([("user", "hi"), ("assistant", "[summary]")])
|
||||
|
||||
@@ -708,7 +708,7 @@ async def upload_cli_session(
|
||||
logger.warning(
|
||||
"%s CLI session file outside projects base, skipping upload: %s",
|
||||
log_prefix,
|
||||
session_file,
|
||||
real_path,
|
||||
)
|
||||
return
|
||||
|
||||
@@ -753,13 +753,9 @@ async def restore_cli_session(
|
||||
or upload failed), in which case the caller should not set --resume.
|
||||
"""
|
||||
storage = await get_workspace_storage()
|
||||
wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id)
|
||||
|
||||
if isinstance(storage, GCSWorkspaceStorage):
|
||||
blob = f"workspaces/{wid}/{fid}/{fname}"
|
||||
path = f"gcs://{storage.bucket_name}/{blob}"
|
||||
else:
|
||||
path = f"local://{wid}/{fid}/{fname}"
|
||||
path = _build_path_from_parts(
|
||||
_cli_session_storage_path_parts(user_id, session_id), storage
|
||||
)
|
||||
|
||||
try:
|
||||
content = await storage.retrieve(path)
|
||||
|
||||
@@ -724,3 +724,155 @@ class TestValidateTranscript:
|
||||
def test_assistant_only_is_valid(self):
|
||||
content = _make_jsonl(ASST_ENTRY)
|
||||
assert validate_transcript(content) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI native session file helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCliSessionPath:
|
||||
def test_encodes_slashes_to_dashes(self):
|
||||
from .transcript import _cli_session_path, _projects_base
|
||||
|
||||
sdk_cwd = "/tmp/copilot-abc"
|
||||
result = _cli_session_path(sdk_cwd, "12345678-1234-1234-1234-123456789abc")
|
||||
base = _projects_base()
|
||||
assert result.startswith(base)
|
||||
# Encoded cwd replaces '/' with '-'
|
||||
assert "-tmp-copilot-abc" in result
|
||||
assert result.endswith(".jsonl")
|
||||
|
||||
def test_sanitizes_session_id(self):
|
||||
from .transcript import _cli_session_path
|
||||
|
||||
result = _cli_session_path("/tmp/cwd", "../../etc/passwd")
|
||||
# _sanitize_id strips non-hex/hyphen chars; path traversal impossible
|
||||
assert ".." not in result
|
||||
assert "passwd" not in result
|
||||
|
||||
|
||||
class TestUploadCliSession:
|
||||
def test_skips_upload_when_path_outside_projects_base(self, tmp_path):
|
||||
"""Files outside the CLI projects base are rejected without upload."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import upload_cli_session
|
||||
|
||||
outside_path_obj = tmp_path / "outside.jsonl"
|
||||
outside_path_obj.write_bytes(b'{"type":"assistant"}\n')
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
return_value="/nonexistent/projects/base",
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
),
|
||||
):
|
||||
asyncio.run(
|
||||
upload_cli_session(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000000",
|
||||
sdk_cwd="/tmp/copilot-test",
|
||||
)
|
||||
)
|
||||
|
||||
# storage.store must NOT be called for paths outside the base
|
||||
mock_storage.store.assert_not_called()
|
||||
|
||||
def test_skips_upload_when_file_not_found(self, tmp_path):
|
||||
"""Missing CLI session file logs debug and skips upload silently."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import upload_cli_session
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
projects_base = str(tmp_path)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
return_value=projects_base,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
),
|
||||
):
|
||||
# session file doesn't exist — should not raise
|
||||
asyncio.run(
|
||||
upload_cli_session(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000000",
|
||||
sdk_cwd=str(tmp_path),
|
||||
)
|
||||
)
|
||||
|
||||
mock_storage.store.assert_not_called()
|
||||
|
||||
|
||||
class TestRestoreCliSession:
|
||||
def test_returns_false_when_file_not_found_in_storage(self):
|
||||
"""Returns False (graceful degradation) when the session is missing."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import restore_cli_session
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.retrieve.side_effect = FileNotFoundError("not found")
|
||||
|
||||
with patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
):
|
||||
result = asyncio.run(
|
||||
restore_cli_session(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000000",
|
||||
sdk_cwd="/tmp/copilot-test",
|
||||
)
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_returns_false_when_restore_path_outside_projects_base(self, tmp_path):
|
||||
"""Path traversal guard: rejects restoration outside the projects base."""
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from .transcript import restore_cli_session
|
||||
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.retrieve.return_value = b'{"type":"assistant"}\n'
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.transcript.get_workspace_storage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_storage,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.transcript._projects_base",
|
||||
return_value="/nonexistent/projects/base",
|
||||
),
|
||||
):
|
||||
result = asyncio.run(
|
||||
restore_cli_session(
|
||||
user_id="user-1",
|
||||
session_id="12345678-0000-0000-0000-000000000000",
|
||||
sdk_cwd="/tmp/copilot-test",
|
||||
)
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
Reference in New Issue
Block a user