diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index 057671d3e3..3baa4b6b5c 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -18,7 +18,6 @@ from backend.copilot import stream_registry from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode from backend.copilot.db import get_chat_messages_paginated from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn -from backend.copilot.message_dedup import acquire_dedup_lock from backend.copilot.model import ( ChatMessage, ChatSession, @@ -192,6 +191,8 @@ class SessionDetailResponse(BaseModel): active_stream: ActiveStreamInfo | None = None # Present if stream is still active has_more_messages: bool = False oldest_sequence: int | None = None + newest_sequence: int | None = None + forward_paginated: bool = False total_prompt_tokens: int = 0 total_completion_tokens: int = 0 metadata: ChatSessionMetadata = ChatSessionMetadata() @@ -456,52 +457,113 @@ async def update_session_title_route( async def get_session( session_id: str, user_id: Annotated[str, Security(auth.get_user_id)], - limit: int = Query(default=50, ge=1, le=200), - before_sequence: int | None = Query(default=None, ge=0), + limit: int = Query( + default=50, + ge=1, + le=200, + description="Maximum number of messages to return.", + ), + before_sequence: int | None = Query( + default=None, + ge=0, + description=( + "Backward pagination cursor. Return messages with sequence number " + "strictly less than this value. Used by active-session load-more. " + "Mutually exclusive with after_sequence." + ), + ), + after_sequence: int | None = Query( + default=None, + ge=0, + description=( + "Forward pagination cursor. Return messages with sequence number " + "strictly greater than this value. Used by completed-session load-more. " + "Mutually exclusive with before_sequence." + ), + ), ) -> SessionDetailResponse: """ Retrieve the details of a specific chat session. - Supports cursor-based pagination via ``limit`` and ``before_sequence``. - When no pagination params are provided, returns the most recent messages. + Supports cursor-based pagination via ``limit``, ``before_sequence``, and + ``after_sequence``. The two cursor parameters are mutually exclusive. - Args: - session_id: The unique identifier for the desired chat session. - user_id: The authenticated user's ID. - limit: Maximum number of messages to return (1-200, default 50). - before_sequence: Return messages with sequence < this value (cursor). - - Returns: - SessionDetailResponse: Details for the requested session, including - active_stream info and pagination metadata. + On the initial load (no cursor provided) of a completed session, messages + are returned in forward order starting from sequence 0 so the user always + sees their initial prompt. Active sessions use the legacy newest-first + order so streaming context is preserved. """ + if before_sequence is not None and after_sequence is not None: + raise HTTPException( + status_code=400, + detail="before_sequence and after_sequence are mutually exclusive", + ) + + is_initial_load = before_sequence is None and after_sequence is None + + # Check active stream before the DB query on initial loads so we can + # choose the correct pagination direction (forward for completed sessions, + # newest-first for active ones). 
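A minimal sketch of how a client could drive the new cursors, assuming a hypothetical http_get helper that returns the JSON body of GET /sessions/{session_id}; the field names follow SessionDetailResponse above:

async def fetch_full_history(http_get, session_id: str) -> list[dict]:
    """Walk a completed session forward until has_more_messages is False."""
    messages: list[dict] = []
    cursor: int | None = None
    while True:
        params: dict[str, int] = {"limit": 200}
        if cursor is not None:
            params["after_sequence"] = cursor  # exclusive: strictly greater
        page = await http_get(f"/sessions/{session_id}", params=params)
        messages.extend(page["messages"])
        if not page["has_more_messages"]:
            return messages
        cursor = page["newest_sequence"]  # resume after the newest seen

Backward paging for an active session is symmetric: pass the returned oldest_sequence as before_sequence on each subsequent request.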
+ active_session = None + last_message_id = None + if is_initial_load: + active_session, last_message_id = await stream_registry.get_active_session( + session_id, user_id + ) + + # Completed sessions on initial load start from sequence 0 so the user's + # initial prompt is always visible. Active sessions keep the legacy + # newest-first behavior to preserve streaming context. + from_start = is_initial_load and active_session is None + forward_paginated = from_start or after_sequence is not None + page = await get_chat_messages_paginated( - session_id, limit, before_sequence, user_id=user_id + session_id, + limit, + before_sequence=before_sequence, + after_sequence=after_sequence, + from_start=from_start, + user_id=user_id, ) if page is None: raise NotFoundError(f"Session {session_id} not found.") + + # Close the TOCTOU window: if the session was active at pre-check, re-verify + # after the DB fetch. The session may have completed between the two awaits, + # which would have caused messages to be fetched newest-first even though the + # session is now complete. Re-fetch from seq 0 so the initial prompt is + # always visible. + if is_initial_load and active_session is not None: + post_active, _ = await stream_registry.get_active_session(session_id, user_id) + if post_active is None: + active_session = None + last_message_id = None + from_start = True + forward_paginated = True + page = await get_chat_messages_paginated( + session_id, + limit, + before_sequence=None, + after_sequence=None, + from_start=True, + user_id=user_id, + ) + if page is None: + raise NotFoundError(f"Session {session_id} not found.") + messages = [ _strip_injected_context(message.model_dump()) for message in page.messages ] - # Only check active stream on initial load (not on "load more" requests) active_stream_info = None - if before_sequence is None: - active_session, last_message_id = await stream_registry.get_active_session( - session_id, user_id + if active_session and last_message_id is not None: + active_stream_info = ActiveStreamInfo( + turn_id=active_session.turn_id, + last_message_id=last_message_id, ) - logger.info( - f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, " - f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}" - ) - if active_session: - active_stream_info = ActiveStreamInfo( - turn_id=active_session.turn_id, - last_message_id=last_message_id, - ) # Skip session metadata on "load more" — frontend only needs messages - if before_sequence is not None: + if not is_initial_load: return SessionDetailResponse( id=page.session.session_id, created_at=page.session.started_at.isoformat(), @@ -511,6 +573,8 @@ async def get_session( active_stream=None, has_more_messages=page.has_more, oldest_sequence=page.oldest_sequence, + newest_sequence=page.newest_sequence, + forward_paginated=forward_paginated, total_prompt_tokens=0, total_completion_tokens=0, ) @@ -527,6 +591,8 @@ async def get_session( active_stream=active_stream_info, has_more_messages=page.has_more, oldest_sequence=page.oldest_sequence, + newest_sequence=page.newest_sequence, + forward_paginated=forward_paginated, total_prompt_tokens=total_prompt, total_completion_tokens=total_completion, metadata=page.session.metadata, @@ -872,9 +938,6 @@ async def stream_chat_post( # Also sanitise file_ids so only validated, workspace-scoped IDs are # forwarded downstream (e.g. to the executor via enqueue_copilot_turn). 
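The guard above is the check, act, re-check shape for a time-of-check/time-of-use race, reduced here to a sketch; get_state and fetch are illustrative stand-ins for stream_registry.get_active_session and get_chat_messages_paginated, not real helpers:

async def read_consistently(get_state, fetch):
    active = await get_state()  # 1. check: is the session still streaming?
    # 2. act: active sessions read newest-first, completed ones from seq 0
    page = await fetch(from_start=active is None)
    # 3. re-check: the session may have completed between the two awaits
    if active is not None and await get_state() is None:
        page = await fetch(from_start=True)  # redo under the state that now holds
    return page

A single re-read suffices here because the route only has to repair the active-to-completed transition; the reverse flip would mean a brand-new POST, which the next initial load picks up on its own.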
sanitized_file_ids: list[str] | None = None - # Capture the original message text BEFORE any mutation (attachment enrichment) - # so the idempotency hash is stable across retries. - original_message = request.message if request.file_ids and user_id: # Filter to valid UUIDs only to prevent DB abuse valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)] @@ -903,58 +966,36 @@ async def stream_chat_post( ) request.message += files_block - # ── Idempotency guard ──────────────────────────────────────────────────── - # Blocks duplicate executor tasks from concurrent/retried POSTs. - # See backend/copilot/message_dedup.py for the full lifecycle description. - dedup_lock = None - if request.is_user_message: - dedup_lock = await acquire_dedup_lock( - session_id, original_message, sanitized_file_ids - ) - if dedup_lock is None and (original_message or sanitized_file_ids): - - async def _empty_sse() -> AsyncGenerator[str, None]: - yield StreamFinish().to_sse() - yield "data: [DONE]\n\n" - - return StreamingResponse( - _empty_sse(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "X-Accel-Buffering": "no", - "Connection": "keep-alive", - "x-vercel-ai-ui-message-stream": "v1", - }, - ) - # Atomically append user message to session BEFORE creating task to avoid # race condition where GET_SESSION sees task as "running" but message isn't - # saved yet. append_and_save_message re-fetches inside a lock to prevent - # message loss from concurrent requests. - # - # If any of these operations raises, release the dedup lock before propagating - # so subsequent retries are not blocked for 30 s. - try: - if request.message: - message = ChatMessage( - role="user" if request.is_user_message else "assistant", - content=request.message, - ) - if request.is_user_message: - track_user_message( - user_id=user_id, - session_id=session_id, - message_length=len(request.message), - ) - logger.info(f"[STREAM] Saving user message to session {session_id}") + # saved yet. append_and_save_message returns None when a duplicate is + # detected — in that case skip enqueue to avoid processing the message twice. + is_duplicate_message = False + if request.message: + message = ChatMessage( + role="user" if request.is_user_message else "assistant", + content=request.message, + ) + logger.info(f"[STREAM] Saving user message to session {session_id}") + is_duplicate_message = ( await append_and_save_message(session_id, message) - logger.info(f"[STREAM] User message saved for session {session_id}") + ) is None + logger.info(f"[STREAM] User message handled for session {session_id} (duplicate={is_duplicate_message})") + if not is_duplicate_message and request.is_user_message: + track_user_message( + user_id=user_id, + session_id=session_id, + message_length=len(request.message), + ) - # Create a task in the stream registry for reconnection support + # Create a task in the stream registry for reconnection support. + # For duplicate messages, skip create_session entirely so the infra-retry + # client subscribes to the *existing* turn's Redis stream and receives the + # in-progress executor output rather than an empty stream. 
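The duplicate branch below leans on append_and_save_message signalling duplicates by returning None. That function is outside this diff; a content-fingerprint check against the last stored message is one plausible shape for the contract (all names in this sketch are illustrative assumptions, not the real implementation):

import hashlib


async def append_if_new(session, message) -> object | None:
    def fingerprint(role: str, content: str) -> str:
        return hashlib.sha256(f"{role}:{content}".encode()).hexdigest()

    last = session.messages[-1] if session.messages else None
    if last is not None and fingerprint(last.role, last.content) == fingerprint(
        message.role, message.content
    ):
        return None  # duplicate: the caller skips create_session and enqueue
    session.messages.append(message)
    await session.save()
    return message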
+ turn_id = "" + if not is_duplicate_message: turn_id = str(uuid4()) log_meta["turn_id"] = turn_id - session_create_start = time.perf_counter() await stream_registry.create_session( session_id=session_id, @@ -972,7 +1013,6 @@ async def stream_chat_post( } }, ) - await enqueue_copilot_turn( session_id=session_id, user_id=user_id, @@ -984,10 +1024,10 @@ async def stream_chat_post( mode=request.mode, model=request.model, ) - except Exception: - if dedup_lock: - await dedup_lock.release() - raise + else: + logger.info( + f"[STREAM] Duplicate message detected for session {session_id}, skipping enqueue" + ) setup_time = (time.perf_counter() - stream_start_time) * 1000 logger.info( @@ -1011,12 +1051,6 @@ async def stream_chat_post( subscriber_queue = None first_chunk_yielded = False chunks_yielded = 0 - # True for every exit path except GeneratorExit (client disconnect). - # On disconnect the backend turn is still running — releasing the lock - # there would reopen the infra-retry duplicate window. The 30 s TTL - # is the fallback. All other exits (normal finish, early return, error) - # should release so the user can re-send the same message. - release_dedup_lock_on_exit = True try: # Subscribe from the position we captured before enqueuing # This avoids replaying old messages while catching all new ones @@ -1028,7 +1062,7 @@ async def stream_chat_post( if subscriber_queue is None: yield StreamFinish().to_sse() - return # finally releases dedup_lock + return # Read from the subscriber queue and yield to SSE logger.info( @@ -1070,7 +1104,7 @@ async def stream_chat_post( } }, ) - break # finally releases dedup_lock + break except asyncio.TimeoutError: yield StreamHeartbeat().to_sse() @@ -1086,7 +1120,6 @@ async def stream_chat_post( } }, ) - release_dedup_lock_on_exit = False except Exception as e: elapsed = (time_module.perf_counter() - event_gen_start) * 1000 logger.error( @@ -1101,10 +1134,7 @@ async def stream_chat_post( code="stream_error", ).to_sse() yield StreamFinish().to_sse() - # finally releases dedup_lock finally: - if dedup_lock and release_dedup_lock_on_exit: - await dedup_lock.release() # Unsubscribe when client disconnects or stream ends if subscriber_queue is not None: try: diff --git a/autogpt_platform/backend/backend/api/features/chat/routes_test.py b/autogpt_platform/backend/backend/api/features/chat/routes_test.py index 597aad01ad..a1ad07deae 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes_test.py @@ -133,21 +133,12 @@ def test_stream_chat_rejects_too_many_file_ids(): assert response.status_code == 422 -def _mock_stream_internals( - mocker: pytest_mock.MockerFixture, - *, - redis_set_returns: object = True, -): +def _mock_stream_internals(mocker: pytest_mock.MockerFixture): """Mock the async internals of stream_chat_post so tests can exercise - validation and enrichment logic without needing Redis/RabbitMQ. - - Args: - redis_set_returns: Value returned by the mocked Redis ``set`` call. - ``True`` (default) simulates a fresh key (new message); - ``None`` simulates a collision (duplicate blocked). + validation and enrichment logic without needing RabbitMQ. Returns: - A namespace with ``redis``, ``save``, and ``enqueue`` mock objects so + A namespace with ``save`` and ``enqueue`` mock objects so callers can make additional assertions about side-effects. 
""" import types @@ -158,7 +149,7 @@ def _mock_stream_internals( ) mock_save = mocker.patch( "backend.api.features.chat.routes.append_and_save_message", - return_value=None, + return_value=MagicMock(), # non-None = message was saved (not a duplicate) ) mock_registry = mocker.MagicMock() mock_registry.create_session = mocker.AsyncMock(return_value=None) @@ -174,15 +165,9 @@ def _mock_stream_internals( "backend.api.features.chat.routes.track_user_message", return_value=None, ) - mock_redis = AsyncMock() - mock_redis.set = AsyncMock(return_value=redis_set_returns) - mocker.patch( - "backend.copilot.message_dedup.get_redis_async", - new_callable=AsyncMock, - return_value=mock_redis, + return types.SimpleNamespace( + save=mock_save, enqueue=mock_enqueue, registry=mock_registry ) - ns = types.SimpleNamespace(redis=mock_redis, save=mock_save, enqueue=mock_enqueue) - return ns def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture): @@ -211,6 +196,29 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture): assert response.status_code == 200 +# ─── Duplicate message dedup ────────────────────────────────────────── + + +def test_stream_chat_skips_enqueue_for_duplicate_message( + mocker: pytest_mock.MockerFixture, +): + """When append_and_save_message returns None (duplicate detected), + enqueue_copilot_turn and stream_registry.create_session must NOT be called + to avoid double-processing and to prevent overwriting the active stream's + turn_id in Redis (which would cause reconnecting clients to miss the response).""" + mocks = _mock_stream_internals(mocker) + # Override save to return None — signalling a duplicate + mocks.save.return_value = None + + response = client.post( + "/sessions/sess-1/stream", + json={"message": "hello"}, + ) + assert response.status_code == 200 + mocks.enqueue.assert_not_called() + mocks.registry.create_session.assert_not_called() + + # ─── UUID format filtering ───────────────────────────────────────────── @@ -706,237 +714,6 @@ class TestStripInjectedContext: assert result["content"] == "hello" -# ─── Idempotency / duplicate-POST guard ────────────────────────────── - - -def test_stream_chat_blocks_duplicate_post_returns_empty_sse( - mocker: pytest_mock.MockerFixture, -) -> None: - """A second POST with the same message within the 30-s window must return - an empty SSE stream (StreamFinish + [DONE]) so the frontend marks the - turn complete without creating a ghost response.""" - # redis_set_returns=None simulates a collision: the NX key already exists. - ns = _mock_stream_internals(mocker, redis_set_returns=None) - - response = client.post( - "/sessions/sess-dup/stream", - json={"message": "duplicate message", "is_user_message": True}, - ) - - assert response.status_code == 200 - body = response.text - # The response must contain StreamFinish (type=finish) and the SSE [DONE] terminator. - assert '"finish"' in body - assert "[DONE]" in body - # The empty SSE response must include the AI SDK protocol header so the - # frontend treats it as a valid stream and marks the turn complete. - assert response.headers.get("x-vercel-ai-ui-message-stream") == "v1" - # The duplicate guard must prevent save/enqueue side effects. 
- ns.save.assert_not_called() - ns.enqueue.assert_not_called() - - -def test_stream_chat_first_post_proceeds_normally( - mocker: pytest_mock.MockerFixture, -) -> None: - """The first POST (Redis NX key set successfully) must proceed through the - normal streaming path — no early return.""" - ns = _mock_stream_internals(mocker, redis_set_returns=True) - - response = client.post( - "/sessions/sess-new/stream", - json={"message": "first message", "is_user_message": True}, - ) - - assert response.status_code == 200 - # Redis set must have been called once with the NX flag. - ns.redis.set.assert_called_once() - call_kwargs = ns.redis.set.call_args - assert call_kwargs.kwargs.get("nx") is True - - -def test_stream_chat_dedup_skipped_for_non_user_messages( - mocker: pytest_mock.MockerFixture, -) -> None: - """System/assistant messages (is_user_message=False) bypass the dedup - guard — they are injected programmatically and must always be processed.""" - ns = _mock_stream_internals(mocker, redis_set_returns=None) - - response = client.post( - "/sessions/sess-sys/stream", - json={"message": "system context", "is_user_message": False}, - ) - - # Even though redis_set_returns=None (would block a user message), - # the endpoint must proceed because is_user_message=False. - assert response.status_code == 200 - ns.redis.set.assert_not_called() - - -def test_stream_chat_dedup_hash_uses_original_message_not_mutated( - mocker: pytest_mock.MockerFixture, -) -> None: - """The dedup hash must be computed from the original request message, - not the mutated version that has the [Attached files] block appended. - A file_id is sent so the route actually appends the [Attached files] block, - exercising the mutation path — the hash must still match the original text.""" - import hashlib - - ns = _mock_stream_internals(mocker, redis_set_returns=True) - - file_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" - # Mock workspace + prisma so the attachment block is actually appended. - mocker.patch( - "backend.api.features.chat.routes.get_or_create_workspace", - return_value=type("W", (), {"id": "ws-1"})(), - ) - fake_file = type( - "F", - (), - { - "id": file_id, - "name": "doc.pdf", - "mimeType": "application/pdf", - "sizeBytes": 1024, - }, - )() - mock_prisma = mocker.MagicMock() - mock_prisma.find_many = mocker.AsyncMock(return_value=[fake_file]) - mocker.patch( - "prisma.models.UserWorkspaceFile.prisma", - return_value=mock_prisma, - ) - - response = client.post( - "/sessions/sess-hash/stream", - json={ - "message": "plain message", - "is_user_message": True, - "file_ids": [file_id], - }, - ) - - assert response.status_code == 200 - ns.redis.set.assert_called_once() - call_args = ns.redis.set.call_args - dedup_key = call_args.args[0] - - # Hash must use the original message + sorted file IDs, not the mutated text. 
- expected_hash = hashlib.sha256( - f"sess-hash:plain message:{file_id}".encode() - ).hexdigest()[:16] - expected_key = f"chat:msg_dedup:sess-hash:{expected_hash}" - assert dedup_key == expected_key, ( - f"Dedup key {dedup_key!r} does not match expected {expected_key!r} — " - "hash may be using mutated message or wrong inputs" - ) - - -def test_stream_chat_dedup_key_released_after_stream_finish( - mocker: pytest_mock.MockerFixture, -) -> None: - """The dedup Redis key must be deleted after the turn completes (when - subscriber_queue is None the route yields StreamFinish immediately and - should release the key so the user can re-send the same message).""" - from unittest.mock import AsyncMock as _AsyncMock - - # Set up all internals manually so we can control subscribe_to_session. - mocker.patch( - "backend.api.features.chat.routes._validate_and_get_session", - return_value=None, - ) - mocker.patch( - "backend.api.features.chat.routes.append_and_save_message", - return_value=None, - ) - mocker.patch( - "backend.api.features.chat.routes.enqueue_copilot_turn", - return_value=None, - ) - mocker.patch( - "backend.api.features.chat.routes.track_user_message", - return_value=None, - ) - mock_registry = mocker.MagicMock() - mock_registry.create_session = _AsyncMock(return_value=None) - # None → early-finish path: StreamFinish yielded immediately, dedup key released. - mock_registry.subscribe_to_session = _AsyncMock(return_value=None) - mocker.patch( - "backend.api.features.chat.routes.stream_registry", - mock_registry, - ) - mock_redis = mocker.AsyncMock() - mock_redis.set = _AsyncMock(return_value=True) - mocker.patch( - "backend.copilot.message_dedup.get_redis_async", - new_callable=_AsyncMock, - return_value=mock_redis, - ) - - response = client.post( - "/sessions/sess-finish/stream", - json={"message": "hello", "is_user_message": True}, - ) - - assert response.status_code == 200 - body = response.text - assert '"finish"' in body - # The dedup key must be released so intentional re-sends are allowed. - mock_redis.delete.assert_called_once() - - -def test_stream_chat_dedup_key_released_even_when_redis_delete_raises( - mocker: pytest_mock.MockerFixture, -) -> None: - """The route must not crash when the dedup Redis delete fails on the - subscriber_queue-is-None early-finish path (except Exception: pass).""" - from unittest.mock import AsyncMock as _AsyncMock - - mocker.patch( - "backend.api.features.chat.routes._validate_and_get_session", - return_value=None, - ) - mocker.patch( - "backend.api.features.chat.routes.append_and_save_message", - return_value=None, - ) - mocker.patch( - "backend.api.features.chat.routes.enqueue_copilot_turn", - return_value=None, - ) - mocker.patch( - "backend.api.features.chat.routes.track_user_message", - return_value=None, - ) - mock_registry = mocker.MagicMock() - mock_registry.create_session = _AsyncMock(return_value=None) - mock_registry.subscribe_to_session = _AsyncMock(return_value=None) - mocker.patch( - "backend.api.features.chat.routes.stream_registry", - mock_registry, - ) - mock_redis = mocker.AsyncMock() - mock_redis.set = _AsyncMock(return_value=True) - # Make the delete raise so the except-pass branch is exercised. - mock_redis.delete = _AsyncMock(side_effect=RuntimeError("redis gone")) - mocker.patch( - "backend.copilot.message_dedup.get_redis_async", - new_callable=_AsyncMock, - return_value=mock_redis, - ) - - # Should not raise even though delete fails. 
- response = client.post( - "/sessions/sess-finish-err/stream", - json={"message": "hello", "is_user_message": True}, - ) - - assert response.status_code == 200 - assert '"finish"' in response.text - # delete must have been attempted — the except-pass branch silenced the error. - mock_redis.delete.assert_called_once() - - # ─── DELETE /sessions/{id}/stream — disconnect listeners ────────────── @@ -980,3 +757,146 @@ def test_disconnect_stream_returns_404_when_session_missing( assert response.status_code == 404 mock_disconnect.assert_not_awaited() + + +# ─── GET /sessions/{session_id} — forward/backward pagination ────────────────── + + +def _make_paginated_messages( + mocker: pytest_mock.MockerFixture, *, has_more: bool = False +): + """Return a mock PaginatedMessages and configure the DB patch.""" + from datetime import UTC, datetime + + from backend.copilot.db import PaginatedMessages + from backend.copilot.model import ChatMessage, ChatSessionInfo, ChatSessionMetadata + + now = datetime.now(UTC) + session_info = ChatSessionInfo( + session_id="sess-1", + user_id=TEST_USER_ID, + usage=[], + started_at=now, + updated_at=now, + metadata=ChatSessionMetadata(), + ) + page = PaginatedMessages( + messages=[ChatMessage(role="user", content="hello", sequence=0)], + has_more=has_more, + oldest_sequence=0, + newest_sequence=0, + session=session_info, + ) + mock_paginate = mocker.patch( + "backend.api.features.chat.routes.get_chat_messages_paginated", + new_callable=AsyncMock, + return_value=page, + ) + return page, mock_paginate + + +def test_get_session_completed_returns_forward_paginated( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + """Completed sessions (no active stream) return forward_paginated=True.""" + _make_paginated_messages(mocker) + mocker.patch( + "backend.api.features.chat.routes.stream_registry.get_active_session", + new_callable=AsyncMock, + return_value=(None, None), + ) + + response = client.get("/sessions/sess-1") + + assert response.status_code == 200 + data = response.json() + assert data["forward_paginated"] is True + assert data["newest_sequence"] == 0 + + +def test_get_session_active_returns_backward_paginated( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + """Active sessions (with running stream) return forward_paginated=False.""" + from backend.copilot.stream_registry import ActiveSession + + _make_paginated_messages(mocker) + active = MagicMock(spec=ActiveSession) + active.turn_id = "turn-1" + mocker.patch( + "backend.api.features.chat.routes.stream_registry.get_active_session", + new_callable=AsyncMock, + return_value=(active, "msg-1"), + ) + + response = client.get("/sessions/sess-1") + + assert response.status_code == 200 + data = response.json() + assert data["forward_paginated"] is False + assert data["active_stream"] is not None + assert data["active_stream"]["turn_id"] == "turn-1" + + +def test_get_session_after_sequence_returns_forward_paginated( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + """after_sequence param returns forward_paginated=True; no stream check needed.""" + _, mock_paginate = _make_paginated_messages(mocker) + + response = client.get("/sessions/sess-1?after_sequence=10") + + assert response.status_code == 200 + data = response.json() + assert data["forward_paginated"] is True + call_kwargs = mock_paginate.call_args + assert call_kwargs.kwargs.get("after_sequence") == 10 + assert call_kwargs.kwargs.get("before_sequence") is None + + +def 
test_get_session_both_cursors_returns_400( + test_user_id: str, +) -> None: + """Sending both before_sequence and after_sequence returns 400.""" + response = client.get("/sessions/sess-1?before_sequence=5&after_sequence=10") + + assert response.status_code == 400 + + +def test_get_session_toctou_refetch_when_session_completes_mid_request( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + """Race condition: session was active at pre-check but completes before DB fetch. + + The route should detect the race via a post-fetch re-check, then re-fetch + from seq 0 so the initial prompt is always visible. + """ + from backend.copilot.stream_registry import ActiveSession + + page, mock_paginate = _make_paginated_messages(mocker) + active = MagicMock(spec=ActiveSession) + active.turn_id = "turn-1" + + # First call: session appears active. Second call: session has completed. + mock_get_active = mocker.patch( + "backend.api.features.chat.routes.stream_registry.get_active_session", + new_callable=AsyncMock, + side_effect=[(active, "msg-1"), (None, None)], + ) + + response = client.get("/sessions/sess-1") + + assert response.status_code == 200 + data = response.json() + # Post-race: session is now completed → forward_paginated=True, no stream + assert data["forward_paginated"] is True + assert data["active_stream"] is None + # The DB was queried twice: once newest-first, once from_start=True + assert mock_paginate.call_count == 2 + assert mock_get_active.call_count == 2 + second_call = mock_paginate.call_args_list[1] + assert second_call.kwargs.get("from_start") is True diff --git a/autogpt_platform/backend/backend/api/features/library/db.py b/autogpt_platform/backend/backend/api/features/library/db.py index fcfc896ea2..0e7357bad3 100644 --- a/autogpt_platform/backend/backend/api/features/library/db.py +++ b/autogpt_platform/backend/backend/api/features/library/db.py @@ -43,6 +43,25 @@ config = Config() integration_creds_manager = IntegrationCredentialsManager() +async def _fetch_execution_counts(user_id: str, graph_ids: list[str]) -> dict[str, int]: + """Fetch execution counts per graph in a single batched query.""" + if not graph_ids: + return {} + rows = await prisma.models.AgentGraphExecution.prisma().group_by( + by=["agentGraphId"], + where={ + "userId": user_id, + "agentGraphId": {"in": graph_ids}, + "isDeleted": False, + }, + count=True, + ) + return { + row["agentGraphId"]: int((row.get("_count") or {}).get("_all") or 0) + for row in rows + } + + async def list_library_agents( user_id: str, search_term: Optional[str] = None, @@ -137,12 +156,18 @@ async def list_library_agents( logger.debug(f"Retrieved {len(library_agents)} library agents for user #{user_id}") + graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId] + execution_counts = await _fetch_execution_counts(user_id, graph_ids) + # Only pass valid agents to the response valid_library_agents: list[library_model.LibraryAgent] = [] for agent in library_agents: try: - library_agent = library_model.LibraryAgent.from_db(agent) + library_agent = library_model.LibraryAgent.from_db( + agent, + execution_count_override=execution_counts.get(agent.agentGraphId), + ) valid_library_agents.append(library_agent) except Exception as e: # Skip this agent if there was an error @@ -214,12 +239,18 @@ async def list_favorite_library_agents( f"Retrieved {len(library_agents)} favorite library agents for user #{user_id}" ) + graph_ids = [a.agentGraphId for a in library_agents if a.agentGraphId] + execution_counts = await 
_fetch_execution_counts(user_id, graph_ids) + # Only pass valid agents to the response valid_library_agents: list[library_model.LibraryAgent] = [] for agent in library_agents: try: - library_agent = library_model.LibraryAgent.from_db(agent) + library_agent = library_model.LibraryAgent.from_db( + agent, + execution_count_override=execution_counts.get(agent.agentGraphId), + ) valid_library_agents.append(library_agent) except Exception as e: # Skip this agent if there was an error diff --git a/autogpt_platform/backend/backend/api/features/library/db_test.py b/autogpt_platform/backend/backend/api/features/library/db_test.py index 5e3e36ac63..562a0bfdfd 100644 --- a/autogpt_platform/backend/backend/api/features/library/db_test.py +++ b/autogpt_platform/backend/backend/api/features/library/db_test.py @@ -65,6 +65,11 @@ async def test_get_library_agents(mocker): ) mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1) + mocker.patch( + "backend.api.features.library.db._fetch_execution_counts", + new=mocker.AsyncMock(return_value={}), + ) + # Call function result = await db.list_library_agents("test-user") @@ -353,3 +358,136 @@ async def test_create_library_agent_uses_upsert(): # Verify update branch restores soft-deleted/archived agents assert data["update"]["isDeleted"] is False assert data["update"]["isArchived"] is False + + +@pytest.mark.asyncio +async def test_list_favorite_library_agents(mocker): + mock_library_agents = [ + prisma.models.LibraryAgent( + id="fav1", + userId="test-user", + agentGraphId="agent-fav", + settings="{}", # type: ignore + agentGraphVersion=1, + isCreatedByUser=False, + isDeleted=False, + isArchived=False, + createdAt=datetime.now(), + updatedAt=datetime.now(), + isFavorite=True, + useGraphIsActiveVersion=True, + AgentGraph=prisma.models.AgentGraph( + id="agent-fav", + version=1, + name="Favorite Agent", + description="My Favorite", + userId="other-user", + isActive=True, + createdAt=datetime.now(), + ), + ) + ] + + mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma") + mock_library_agent.return_value.find_many = mocker.AsyncMock( + return_value=mock_library_agents + ) + mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1) + + mocker.patch( + "backend.api.features.library.db._fetch_execution_counts", + new=mocker.AsyncMock(return_value={"agent-fav": 7}), + ) + + result = await db.list_favorite_library_agents("test-user") + + assert len(result.agents) == 1 + assert result.agents[0].id == "fav1" + assert result.agents[0].name == "Favorite Agent" + assert result.agents[0].graph_id == "agent-fav" + assert result.pagination.total_items == 1 + assert result.pagination.total_pages == 1 + assert result.pagination.current_page == 1 + assert result.pagination.page_size == 50 + + +@pytest.mark.asyncio +async def test_list_library_agents_skips_failed_agent(mocker): + """Agents that fail parsing should be skipped — covers the except branch.""" + mock_library_agents = [ + prisma.models.LibraryAgent( + id="ua-bad", + userId="test-user", + agentGraphId="agent-bad", + settings="{}", # type: ignore + agentGraphVersion=1, + isCreatedByUser=False, + isDeleted=False, + isArchived=False, + createdAt=datetime.now(), + updatedAt=datetime.now(), + isFavorite=False, + useGraphIsActiveVersion=True, + AgentGraph=prisma.models.AgentGraph( + id="agent-bad", + version=1, + name="Bad Agent", + description="", + userId="other-user", + isActive=True, + createdAt=datetime.now(), + ), + ) + ] + + mock_library_agent = 
mocker.patch("prisma.models.LibraryAgent.prisma") + mock_library_agent.return_value.find_many = mocker.AsyncMock( + return_value=mock_library_agents + ) + mock_library_agent.return_value.count = mocker.AsyncMock(return_value=1) + + mocker.patch( + "backend.api.features.library.db._fetch_execution_counts", + new=mocker.AsyncMock(return_value={}), + ) + mocker.patch( + "backend.api.features.library.model.LibraryAgent.from_db", + side_effect=Exception("parse error"), + ) + + result = await db.list_library_agents("test-user") + + assert len(result.agents) == 0 + assert result.pagination.total_items == 1 + + +@pytest.mark.asyncio +async def test_fetch_execution_counts_empty_graph_ids(): + result = await db._fetch_execution_counts("user-1", []) + assert result == {} + + +@pytest.mark.asyncio +async def test_fetch_execution_counts_uses_group_by(mocker): + mock_prisma = mocker.patch("prisma.models.AgentGraphExecution.prisma") + mock_prisma.return_value.group_by = mocker.AsyncMock( + return_value=[ + {"agentGraphId": "graph-1", "_count": {"_all": 5}}, + {"agentGraphId": "graph-2", "_count": {"_all": 2}}, + ] + ) + + result = await db._fetch_execution_counts( + "user-1", ["graph-1", "graph-2", "graph-3"] + ) + + assert result == {"graph-1": 5, "graph-2": 2} + mock_prisma.return_value.group_by.assert_called_once_with( + by=["agentGraphId"], + where={ + "userId": "user-1", + "agentGraphId": {"in": ["graph-1", "graph-2", "graph-3"]}, + "isDeleted": False, + }, + count=True, + ) diff --git a/autogpt_platform/backend/backend/api/features/library/model.py b/autogpt_platform/backend/backend/api/features/library/model.py index 7211a7ebfe..26251a2cd1 100644 --- a/autogpt_platform/backend/backend/api/features/library/model.py +++ b/autogpt_platform/backend/backend/api/features/library/model.py @@ -223,6 +223,7 @@ class LibraryAgent(pydantic.BaseModel): sub_graphs: Optional[list[prisma.models.AgentGraph]] = None, store_listing: Optional[prisma.models.StoreListing] = None, profile: Optional[prisma.models.Profile] = None, + execution_count_override: Optional[int] = None, ) -> "LibraryAgent": """ Factory method that constructs a LibraryAgent from a Prisma LibraryAgent @@ -258,10 +259,14 @@ class LibraryAgent(pydantic.BaseModel): status = status_result.status new_output = status_result.new_output - execution_count = len(executions) + execution_count = ( + execution_count_override + if execution_count_override is not None + else len(executions) + ) success_rate: float | None = None avg_correctness_score: float | None = None - if execution_count > 0: + if executions and execution_count > 0: success_count = sum( 1 for e in executions diff --git a/autogpt_platform/backend/backend/api/features/library/model_test.py b/autogpt_platform/backend/backend/api/features/library/model_test.py index a32b19322d..31924a1793 100644 --- a/autogpt_platform/backend/backend/api/features/library/model_test.py +++ b/autogpt_platform/backend/backend/api/features/library/model_test.py @@ -1,11 +1,66 @@ import datetime +import prisma.enums import prisma.models import pytest from . 
import model as library_model +def _make_library_agent( + *, + graph_id: str = "g1", + executions: list | None = None, +) -> prisma.models.LibraryAgent: + return prisma.models.LibraryAgent( + id="la1", + userId="u1", + agentGraphId=graph_id, + settings="{}", # type: ignore + agentGraphVersion=1, + isCreatedByUser=True, + isDeleted=False, + isArchived=False, + createdAt=datetime.datetime.now(), + updatedAt=datetime.datetime.now(), + isFavorite=False, + useGraphIsActiveVersion=True, + AgentGraph=prisma.models.AgentGraph( + id=graph_id, + version=1, + name="Agent", + description="Desc", + userId="u1", + isActive=True, + createdAt=datetime.datetime.now(), + Executions=executions, + ), + ) + + +def test_from_db_execution_count_override_covers_success_rate(): + """Covers execution_count_override is not None branch and executions/count > 0 block.""" + now = datetime.datetime.now(datetime.timezone.utc) + exec1 = prisma.models.AgentGraphExecution( + id="exec-1", + agentGraphId="g1", + agentGraphVersion=1, + userId="u1", + executionStatus=prisma.enums.AgentExecutionStatus.COMPLETED, + createdAt=now, + updatedAt=now, + isDeleted=False, + isShared=False, + ) + agent = _make_library_agent(executions=[exec1]) + + result = library_model.LibraryAgent.from_db(agent, execution_count_override=1) + + assert result.execution_count == 1 + assert result.success_rate is not None + assert result.success_rate == 100.0 + + @pytest.mark.asyncio async def test_agent_preset_from_db(test_user_id: str): # Create mock DB agent diff --git a/autogpt_platform/backend/backend/api/features/subscription_routes_test.py b/autogpt_platform/backend/backend/api/features/subscription_routes_test.py index 7a7ec518c6..c20e0d0ceb 100644 --- a/autogpt_platform/backend/backend/api/features/subscription_routes_test.py +++ b/autogpt_platform/backend/backend/api/features/subscription_routes_test.py @@ -4,291 +4,802 @@ from unittest.mock import AsyncMock, Mock import fastapi import fastapi.testclient +import pytest import pytest_mock +import stripe from autogpt_libs.auth.jwt_utils import get_jwt_payload from prisma.enums import SubscriptionTier -from .v1 import v1_router - -app = fastapi.FastAPI() -app.include_router(v1_router) - -client = fastapi.testclient.TestClient(app) +from .v1 import _validate_checkout_redirect_url, v1_router TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a" +TEST_FRONTEND_ORIGIN = "https://app.example.com" -def setup_auth(app: fastapi.FastAPI): +@pytest.fixture() +def client() -> fastapi.testclient.TestClient: + """Fresh FastAPI app + client per test with auth override applied. + + Using a fixture avoids the leaky global-app + try/finally teardown pattern: + if a test body raises before teardown_auth runs, dependency overrides were + previously leaking into subsequent tests. 
+ """ + app = fastapi.FastAPI() + app.include_router(v1_router) + def override_get_jwt_payload(request: fastapi.Request) -> dict[str, str]: return {"sub": TEST_USER_ID, "role": "user", "email": "test@example.com"} app.dependency_overrides[get_jwt_payload] = override_get_jwt_payload + try: + yield fastapi.testclient.TestClient(app) + finally: + app.dependency_overrides.clear() -def teardown_auth(app: fastapi.FastAPI): - app.dependency_overrides.clear() +@pytest.fixture(autouse=True) +def _configure_frontend_origin(mocker: pytest_mock.MockFixture) -> None: + """Pin the configured frontend origin used by the open-redirect guard.""" + from backend.api.features import v1 as v1_mod + + mocker.patch.object( + v1_mod.settings.config, "frontend_base_url", TEST_FRONTEND_ORIGIN + ) + + +@pytest.mark.parametrize( + "url,expected", + [ + # Valid URLs matching the configured frontend origin + (f"{TEST_FRONTEND_ORIGIN}/success", True), + (f"{TEST_FRONTEND_ORIGIN}/cancel?ref=abc", True), + # Wrong origin + ("https://evil.example.org/phish", False), + ("https://evil.example.org", False), + # @ in URL (user:pass@host attack) + (f"https://attacker.example.com@{TEST_FRONTEND_ORIGIN}/ok", False), + # Backslash normalisation attack + (f"https:{TEST_FRONTEND_ORIGIN}\\@attacker.example.com/ok", False), + # javascript: scheme + ("javascript:alert(1)", False), + # Empty string + ("", False), + # Control character (U+0000) in URL + (f"{TEST_FRONTEND_ORIGIN}/ok\x00evil", False), + # Non-http scheme + (f"ftp://{TEST_FRONTEND_ORIGIN}/ok", False), + ], +) +def test_validate_checkout_redirect_url( + url: str, + expected: bool, + mocker: pytest_mock.MockFixture, +) -> None: + """_validate_checkout_redirect_url rejects adversarial inputs.""" + from backend.api.features import v1 as v1_mod + + mocker.patch.object( + v1_mod.settings.config, "frontend_base_url", TEST_FRONTEND_ORIGIN + ) + assert _validate_checkout_redirect_url(url) is expected def test_get_subscription_status_pro( + client: fastapi.testclient.TestClient, mocker: pytest_mock.MockFixture, ) -> None: """GET /credits/subscription returns PRO tier with Stripe price for a PRO user.""" - setup_auth(app) - try: - mock_user = Mock() - mock_user.subscription_tier = SubscriptionTier.PRO + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO - mock_price = Mock() - mock_price.unit_amount = 1999 # $19.99 + async def mock_price_id(tier: SubscriptionTier) -> str | None: + return "price_pro" if tier == SubscriptionTier.PRO else None - async def mock_price_id(tier: SubscriptionTier) -> str | None: - return "price_pro" if tier == SubscriptionTier.PRO else None + async def mock_stripe_price_amount(price_id: str) -> int: + return 1999 if price_id == "price_pro" else 0 - mocker.patch( - "backend.api.features.v1.get_user_by_id", - new_callable=AsyncMock, - return_value=mock_user, - ) - mocker.patch( - "backend.api.features.v1.get_subscription_price_id", - side_effect=mock_price_id, - ) - mocker.patch( - "backend.api.features.v1.stripe.Price.retrieve", - return_value=mock_price, - ) + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + side_effect=mock_price_id, + ) + mocker.patch( + "backend.api.features.v1._get_stripe_price_amount", + side_effect=mock_stripe_price_amount, + ) + mocker.patch( + "backend.api.features.v1.get_proration_credit_cents", + new_callable=AsyncMock, + return_value=500, + ) - response = 
client.get("/credits/subscription") + response = client.get("/credits/subscription") - assert response.status_code == 200 - data = response.json() - assert data["tier"] == "PRO" - assert data["monthly_cost"] == 1999 - assert data["tier_costs"]["PRO"] == 1999 - assert data["tier_costs"]["BUSINESS"] == 0 - assert data["tier_costs"]["FREE"] == 0 - finally: - teardown_auth(app) + assert response.status_code == 200 + data = response.json() + assert data["tier"] == "PRO" + assert data["monthly_cost"] == 1999 + assert data["tier_costs"]["PRO"] == 1999 + assert data["tier_costs"]["BUSINESS"] == 0 + assert data["tier_costs"]["FREE"] == 0 + assert data["proration_credit_cents"] == 500 def test_get_subscription_status_defaults_to_free( + client: fastapi.testclient.TestClient, mocker: pytest_mock.MockFixture, ) -> None: """GET /credits/subscription when subscription_tier is None defaults to FREE.""" - setup_auth(app) - try: - mock_user = Mock() - mock_user.subscription_tier = None + mock_user = Mock() + mock_user.subscription_tier = None - mocker.patch( - "backend.api.features.v1.get_user_by_id", - new_callable=AsyncMock, - return_value=mock_user, - ) - mocker.patch( - "backend.api.features.v1.get_subscription_price_id", - new_callable=AsyncMock, - return_value=None, - ) + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + new_callable=AsyncMock, + return_value=None, + ) + mocker.patch( + "backend.api.features.v1.get_proration_credit_cents", + new_callable=AsyncMock, + return_value=0, + ) - response = client.get("/credits/subscription") + response = client.get("/credits/subscription") - assert response.status_code == 200 - data = response.json() - assert data["tier"] == SubscriptionTier.FREE.value - assert data["monthly_cost"] == 0 - assert data["tier_costs"] == { - "FREE": 0, - "PRO": 0, - "BUSINESS": 0, - "ENTERPRISE": 0, - } - finally: - teardown_auth(app) + assert response.status_code == 200 + data = response.json() + assert data["tier"] == SubscriptionTier.FREE.value + assert data["monthly_cost"] == 0 + assert data["tier_costs"] == { + "FREE": 0, + "PRO": 0, + "BUSINESS": 0, + "ENTERPRISE": 0, + } + assert data["proration_credit_cents"] == 0 + + +def test_get_subscription_status_stripe_error_falls_back_to_zero( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """GET /credits/subscription returns cost=0 when Stripe price fetch fails (returns None). + + _get_stripe_price_amount returns None on StripeError so the error state is + not cached. The endpoint must treat None as 0 — not raise or return invalid data. 
+ """ + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + return "price_pro" if tier == SubscriptionTier.PRO else None + + async def mock_stripe_price_amount_none(price_id: str) -> None: + return None + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + side_effect=mock_price_id, + ) + mocker.patch( + "backend.api.features.v1._get_stripe_price_amount", + side_effect=mock_stripe_price_amount_none, + ) + mocker.patch( + "backend.api.features.v1.get_proration_credit_cents", + new_callable=AsyncMock, + return_value=0, + ) + + response = client.get("/credits/subscription") + + assert response.status_code == 200 + data = response.json() + assert data["tier"] == "PRO" + # When Stripe returns None, cost falls back to 0 + assert data["monthly_cost"] == 0 + assert data["tier_costs"]["PRO"] == 0 def test_update_subscription_tier_free_no_payment( + client: fastapi.testclient.TestClient, mocker: pytest_mock.MockFixture, ) -> None: """POST /credits/subscription to FREE tier when payment disabled skips Stripe.""" - setup_auth(app) - try: - mock_user = Mock() - mock_user.subscription_tier = SubscriptionTier.PRO + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO - async def mock_feature_disabled(*args, **kwargs): - return False + async def mock_feature_disabled(*args, **kwargs): + return False - async def mock_set_tier(*args, **kwargs): - pass + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + side_effect=mock_feature_disabled, + ) + mocker.patch( + "backend.api.features.v1.set_subscription_tier", + new_callable=AsyncMock, + ) - mocker.patch( - "backend.api.features.v1.get_user_by_id", - new_callable=AsyncMock, - return_value=mock_user, - ) - mocker.patch( - "backend.api.features.v1.is_feature_enabled", - side_effect=mock_feature_disabled, - ) - mocker.patch( - "backend.api.features.v1.set_subscription_tier", - side_effect=mock_set_tier, - ) + response = client.post("/credits/subscription", json={"tier": "FREE"}) - response = client.post("/credits/subscription", json={"tier": "FREE"}) - - assert response.status_code == 200 - assert response.json()["url"] == "" - finally: - teardown_auth(app) + assert response.status_code == 200 + assert response.json()["url"] == "" def test_update_subscription_tier_paid_beta_user( + client: fastapi.testclient.TestClient, mocker: pytest_mock.MockFixture, ) -> None: - """POST /credits/subscription for paid tier when payment disabled sets tier directly.""" - setup_auth(app) - try: - mock_user = Mock() - mock_user.subscription_tier = SubscriptionTier.FREE + """POST /credits/subscription for paid tier when payment disabled returns 422.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.FREE - async def mock_feature_disabled(*args, **kwargs): - return False + async def mock_feature_disabled(*args, **kwargs): + return False - async def mock_set_tier(*args, **kwargs): - pass + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + side_effect=mock_feature_disabled, + ) - mocker.patch( - "backend.api.features.v1.get_user_by_id", - new_callable=AsyncMock, - 
return_value=mock_user, - ) - mocker.patch( - "backend.api.features.v1.is_feature_enabled", - side_effect=mock_feature_disabled, - ) - mocker.patch( - "backend.api.features.v1.set_subscription_tier", - side_effect=mock_set_tier, - ) + response = client.post("/credits/subscription", json={"tier": "PRO"}) - response = client.post("/credits/subscription", json={"tier": "PRO"}) - - assert response.status_code == 200 - assert response.json()["url"] == "" - finally: - teardown_auth(app) + assert response.status_code == 422 + assert "not available" in response.json()["detail"] def test_update_subscription_tier_paid_requires_urls( + client: fastapi.testclient.TestClient, mocker: pytest_mock.MockFixture, ) -> None: """POST /credits/subscription for paid tier without success/cancel URLs returns 422.""" - setup_auth(app) - try: - mock_user = Mock() - mock_user.subscription_tier = SubscriptionTier.FREE + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.FREE - async def mock_feature_enabled(*args, **kwargs): - return True + async def mock_feature_enabled(*args, **kwargs): + return True - mocker.patch( - "backend.api.features.v1.get_user_by_id", - new_callable=AsyncMock, - return_value=mock_user, - ) - mocker.patch( - "backend.api.features.v1.is_feature_enabled", - side_effect=mock_feature_enabled, - ) + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + side_effect=mock_feature_enabled, + ) - response = client.post("/credits/subscription", json={"tier": "PRO"}) + response = client.post("/credits/subscription", json={"tier": "PRO"}) - assert response.status_code == 422 - finally: - teardown_auth(app) + assert response.status_code == 422 def test_update_subscription_tier_creates_checkout( + client: fastapi.testclient.TestClient, mocker: pytest_mock.MockFixture, ) -> None: """POST /credits/subscription creates Stripe Checkout Session for paid upgrade.""" - setup_auth(app) - try: - mock_user = Mock() - mock_user.subscription_tier = SubscriptionTier.FREE + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.FREE - async def mock_feature_enabled(*args, **kwargs): - return True + async def mock_feature_enabled(*args, **kwargs): + return True - mocker.patch( - "backend.api.features.v1.get_user_by_id", - new_callable=AsyncMock, - return_value=mock_user, - ) - mocker.patch( - "backend.api.features.v1.is_feature_enabled", - side_effect=mock_feature_enabled, - ) - mocker.patch( - "backend.api.features.v1.create_subscription_checkout", - new_callable=AsyncMock, - return_value="https://checkout.stripe.com/pay/cs_test_abc", - ) + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + side_effect=mock_feature_enabled, + ) + mocker.patch( + "backend.api.features.v1.create_subscription_checkout", + new_callable=AsyncMock, + return_value="https://checkout.stripe.com/pay/cs_test_abc", + ) - response = client.post( - "/credits/subscription", - json={ - "tier": "PRO", - "success_url": "https://app.example.com/success", - "cancel_url": "https://app.example.com/cancel", - }, - ) + response = client.post( + "/credits/subscription", + json={ + "tier": "PRO", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) - assert response.status_code == 200 - assert response.json()["url"] == 
"https://checkout.stripe.com/pay/cs_test_abc" - finally: - teardown_auth(app) + assert response.status_code == 200 + assert response.json()["url"] == "https://checkout.stripe.com/pay/cs_test_abc" -def test_update_subscription_tier_free_with_payment_cancels_stripe( +def test_update_subscription_tier_rejects_open_redirect( + client: fastapi.testclient.TestClient, mocker: pytest_mock.MockFixture, ) -> None: - """Downgrading to FREE cancels active Stripe subscription when payment is enabled.""" - setup_auth(app) - try: - mock_user = Mock() - mock_user.subscription_tier = SubscriptionTier.PRO + """POST /credits/subscription rejects success/cancel URLs outside the frontend origin.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.FREE - async def mock_feature_enabled(*args, **kwargs): - return True + async def mock_feature_enabled(*args, **kwargs): + return True - mock_cancel = mocker.patch( - "backend.api.features.v1.cancel_stripe_subscription", - new_callable=AsyncMock, - ) + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + side_effect=mock_feature_enabled, + ) + checkout_mock = mocker.patch( + "backend.api.features.v1.create_subscription_checkout", + new_callable=AsyncMock, + ) - async def mock_set_tier(*args, **kwargs): - pass + response = client.post( + "/credits/subscription", + json={ + "tier": "PRO", + "success_url": "https://evil.example.org/phish", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) - mocker.patch( - "backend.api.features.v1.get_user_by_id", - new_callable=AsyncMock, - return_value=mock_user, - ) - mocker.patch( - "backend.api.features.v1.set_subscription_tier", - side_effect=mock_set_tier, - ) - mocker.patch( - "backend.api.features.v1.is_feature_enabled", - side_effect=mock_feature_enabled, - ) + assert response.status_code == 422 + checkout_mock.assert_not_awaited() - response = client.post("/credits/subscription", json={"tier": "FREE"}) - assert response.status_code == 200 - mock_cancel.assert_awaited_once() - finally: - teardown_auth(app) +def test_update_subscription_tier_enterprise_blocked( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """ENTERPRISE users cannot self-service change tiers — must get 403.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.ENTERPRISE + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + set_tier_mock = mocker.patch( + "backend.api.features.v1.set_subscription_tier", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "PRO", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 403 + set_tier_mock.assert_not_awaited() + + +def test_update_subscription_tier_same_tier_is_noop( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """POST /credits/subscription for the user's current paid tier returns 200 with empty URL. + + Without this guard a duplicate POST (double-click, browser retry, stale page) would + create a second Stripe Checkout Session for the same price, potentially billing the + user twice until the webhook reconciliation fires. 
+ """ + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + async def mock_feature_enabled(*args, **kwargs): + return True + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + side_effect=mock_feature_enabled, + ) + checkout_mock = mocker.patch( + "backend.api.features.v1.create_subscription_checkout", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "PRO", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 200 + assert response.json()["url"] == "" + checkout_mock.assert_not_awaited() + + +def test_update_subscription_tier_free_with_payment_schedules_cancel_and_does_not_update_db( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """Downgrading to FREE schedules Stripe cancellation at period end. + + The DB tier must NOT be updated immediately — the customer.subscription.deleted + webhook fires at period end and downgrades to FREE then. + """ + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + async def mock_feature_enabled(*args, **kwargs): + return True + + mock_cancel = mocker.patch( + "backend.api.features.v1.cancel_stripe_subscription", + new_callable=AsyncMock, + ) + mock_set_tier = mocker.patch( + "backend.api.features.v1.set_subscription_tier", + new_callable=AsyncMock, + ) + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + side_effect=mock_feature_enabled, + ) + + response = client.post("/credits/subscription", json={"tier": "FREE"}) + + assert response.status_code == 200 + mock_cancel.assert_awaited_once() + mock_set_tier.assert_not_awaited() + + +def test_update_subscription_tier_free_cancel_failure_returns_502( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """Downgrading to FREE returns 502 with a generic error (no Stripe detail leakage).""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + async def mock_feature_enabled(*args, **kwargs): + return True + + mocker.patch( + "backend.api.features.v1.cancel_stripe_subscription", + side_effect=stripe.StripeError( + "You did not provide an API key — internal detail that must not leak" + ), + ) + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + side_effect=mock_feature_enabled, + ) + + response = client.post("/credits/subscription", json={"tier": "FREE"}) + + assert response.status_code == 502 + detail = response.json()["detail"] + # The raw Stripe error message must not appear in the client-facing detail. + assert "API key" not in detail + assert "contact support" in detail.lower() + + +def test_stripe_webhook_unconfigured_secret_returns_503( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """Stripe webhook endpoint returns 503 when STRIPE_WEBHOOK_SECRET is not set. + + An empty webhook secret allows HMAC forgery: an attacker can compute a valid + HMAC signature over the same empty key. 
The handler must reject all requests + when the secret is unconfigured rather than proceeding with signature verification. + """ + mocker.patch( + "backend.api.features.v1.settings.secrets.stripe_webhook_secret", + new="", + ) + response = client.post( + "/credits/stripe_webhook", + content=b"{}", + headers={"stripe-signature": "t=1,v1=fake"}, + ) + assert response.status_code == 503 + + +def test_stripe_webhook_dispatches_subscription_events( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """POST /credits/stripe_webhook routes customer.subscription.created to sync handler.""" + stripe_sub_obj = { + "id": "sub_test", + "customer": "cus_test", + "status": "active", + "items": {"data": [{"price": {"id": "price_pro"}}]}, + } + event = { + "type": "customer.subscription.created", + "data": {"object": stripe_sub_obj}, + } + + # Ensure the webhook secret guard passes (non-empty secret required). + mocker.patch( + "backend.api.features.v1.settings.secrets.stripe_webhook_secret", + new="whsec_test", + ) + mocker.patch( + "backend.api.features.v1.stripe.Webhook.construct_event", + return_value=event, + ) + sync_mock = mocker.patch( + "backend.api.features.v1.sync_subscription_from_stripe", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/stripe_webhook", + content=b"{}", + headers={"stripe-signature": "t=1,v1=abc"}, + ) + + assert response.status_code == 200 + sync_mock.assert_awaited_once_with(stripe_sub_obj) + + +def test_stripe_webhook_dispatches_invoice_payment_failed( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """POST /credits/stripe_webhook routes invoice.payment_failed to the failure handler.""" + invoice_obj = { + "customer": "cus_test", + "subscription": "sub_test", + "amount_due": 1999, + } + event = { + "type": "invoice.payment_failed", + "data": {"object": invoice_obj}, + } + + mocker.patch( + "backend.api.features.v1.settings.secrets.stripe_webhook_secret", + new="whsec_test", + ) + mocker.patch( + "backend.api.features.v1.stripe.Webhook.construct_event", + return_value=event, + ) + failure_mock = mocker.patch( + "backend.api.features.v1.handle_subscription_payment_failure", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/stripe_webhook", + content=b"{}", + headers={"stripe-signature": "t=1,v1=abc"}, + ) + + assert response.status_code == 200 + failure_mock.assert_awaited_once_with(invoice_obj) + + +def test_update_subscription_tier_paid_to_paid_modifies_subscription( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """POST /credits/subscription modifies existing subscription for paid→paid changes.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + new_callable=AsyncMock, + return_value=True, + ) + modify_mock = mocker.patch( + "backend.api.features.v1.modify_stripe_subscription_for_tier", + new_callable=AsyncMock, + return_value=True, + ) + checkout_mock = mocker.patch( + "backend.api.features.v1.create_subscription_checkout", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "BUSINESS", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 200 + assert 
response.json()["url"] == "" + modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.BUSINESS) + checkout_mock.assert_not_awaited() + + +def test_update_subscription_tier_admin_granted_paid_to_paid_updates_db_directly( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """Admin-granted paid tier users are NOT sent to Stripe checkout for paid→paid changes. + + When modify_stripe_subscription_for_tier returns False (no Stripe subscription + found — admin-granted tier), the endpoint must update the DB tier directly and + return 200 with url="", rather than falling through to Checkout Session creation. + """ + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + new_callable=AsyncMock, + return_value=True, + ) + # Return False = no Stripe subscription (admin-granted tier) + modify_mock = mocker.patch( + "backend.api.features.v1.modify_stripe_subscription_for_tier", + new_callable=AsyncMock, + return_value=False, + ) + set_tier_mock = mocker.patch( + "backend.api.features.v1.set_subscription_tier", + new_callable=AsyncMock, + ) + checkout_mock = mocker.patch( + "backend.api.features.v1.create_subscription_checkout", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "BUSINESS", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 200 + assert response.json()["url"] == "" + modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.BUSINESS) + # DB tier updated directly — no Stripe Checkout Session created + set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.BUSINESS) + checkout_mock.assert_not_awaited() + + +def test_update_subscription_tier_paid_to_paid_stripe_error_returns_502( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """POST /credits/subscription returns 502 when Stripe modification fails.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + new_callable=AsyncMock, + return_value=True, + ) + mocker.patch( + "backend.api.features.v1.modify_stripe_subscription_for_tier", + new_callable=AsyncMock, + side_effect=stripe.StripeError("connection error"), + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "BUSINESS", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 502 + + +def test_update_subscription_tier_free_no_stripe_subscription( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """Downgrading to FREE when no Stripe subscription exists updates DB tier directly. + + Admin-granted paid tiers have no associated Stripe subscription. When such a + user requests a self-service downgrade, cancel_stripe_subscription returns False + (nothing to cancel), so the endpoint must immediately call set_subscription_tier + rather than waiting for a webhook that will never arrive. 
+ """ + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + new_callable=AsyncMock, + return_value=True, + ) + # Simulate no active Stripe subscriptions — returns False + cancel_mock = mocker.patch( + "backend.api.features.v1.cancel_stripe_subscription", + new_callable=AsyncMock, + return_value=False, + ) + set_tier_mock = mocker.patch( + "backend.api.features.v1.set_subscription_tier", + new_callable=AsyncMock, + ) + + response = client.post("/credits/subscription", json={"tier": "FREE"}) + + assert response.status_code == 200 + assert response.json()["url"] == "" + cancel_mock.assert_awaited_once_with(TEST_USER_ID) + # DB tier must be updated immediately — no webhook will fire for a missing sub + set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.FREE) diff --git a/autogpt_platform/backend/backend/api/features/v1.py b/autogpt_platform/backend/backend/api/features/v1.py index 5767cebd94..ab0b69071d 100644 --- a/autogpt_platform/backend/backend/api/features/v1.py +++ b/autogpt_platform/backend/backend/api/features/v1.py @@ -5,7 +5,8 @@ import time import uuid from collections import defaultdict from datetime import datetime, timezone -from typing import Annotated, Any, Literal, Sequence, get_args +from typing import Annotated, Any, Literal, Sequence, cast, get_args +from urllib.parse import urlparse import pydantic import stripe @@ -54,8 +55,11 @@ from backend.data.credit import ( cancel_stripe_subscription, create_subscription_checkout, get_auto_top_up, + get_proration_credit_cents, get_subscription_price_id, get_user_credit_model, + handle_subscription_payment_failure, + modify_stripe_subscription_for_tier, set_auto_top_up, set_subscription_tier, sync_subscription_from_stripe, @@ -699,9 +703,72 @@ class SubscriptionCheckoutResponse(BaseModel): class SubscriptionStatusResponse(BaseModel): - tier: str - monthly_cost: int - tier_costs: dict[str, int] + tier: Literal["FREE", "PRO", "BUSINESS", "ENTERPRISE"] + monthly_cost: int # amount in cents (Stripe convention) + tier_costs: dict[str, int] # tier name -> amount in cents + proration_credit_cents: int # unused portion of current sub to convert on upgrade + + +def _validate_checkout_redirect_url(url: str) -> bool: + """Return True if `url` matches the configured frontend origin. + + Prevents open-redirect: attackers must not be able to supply arbitrary + success_url/cancel_url that Stripe will redirect users to after checkout. + + Pre-parse rejection rules (applied before urlparse): + - Backslashes (``\\``) are normalised differently across parsers/browsers. + - Control characters (U+0000–U+001F) are not valid in URLs and may confuse + some URL-parsing implementations. + """ + # Reject characters that can confuse URL parsers before any parsing. + if "\\" in url: + return False + if any(ord(c) < 0x20 for c in url): + return False + + allowed = settings.config.frontend_base_url or settings.config.platform_base_url + if not allowed: + # No configured origin — refuse to validate rather than allow arbitrary URLs. 
+ return False + try: + parsed = urlparse(url) + allowed_parsed = urlparse(allowed) + except ValueError: + return False + if parsed.scheme not in ("http", "https"): + return False + # Reject ``user:pass@host`` authority tricks — ``@`` in the netloc component + # can trick browsers into connecting to a different host than displayed. + # ``@`` in query/fragment is harmless and must be allowed. + if "@" in parsed.netloc: + return False + return ( + parsed.scheme == allowed_parsed.scheme + and parsed.netloc == allowed_parsed.netloc + ) + + +@cached(ttl_seconds=300, maxsize=32, cache_none=False) +async def _get_stripe_price_amount(price_id: str) -> int | None: + """Return the unit_amount (cents) for a Stripe Price ID, cached for 5 minutes. + + Returns ``None`` on transient Stripe errors. ``cache_none=False`` opts out + of caching the ``None`` sentinel so the next request retries Stripe instead + of being served a stale "no price" for the rest of the TTL window. Callers + should treat ``None`` as an unknown price and fall back to 0. + + Stripe prices rarely change; caching avoids a ~200-600 ms Stripe round-trip on + every GET /credits/subscription page load and reduces quota consumption. + """ + try: + price = await run_in_threadpool(stripe.Price.retrieve, price_id) + return price.unit_amount or 0 + except stripe.StripeError: + logger.warning( + "Failed to retrieve Stripe price %s — returning None (not cached)", + price_id, + ) + return None @v1_router.get( @@ -722,21 +789,26 @@ async def get_subscription_status( *[get_subscription_price_id(t) for t in paid_tiers] ) - tier_costs: dict[str, int] = {"FREE": 0, "ENTERPRISE": 0} - for t, price_id in zip(paid_tiers, price_ids): - cost = 0 - if price_id: - try: - price = await run_in_threadpool(stripe.Price.retrieve, price_id) - cost = price.unit_amount or 0 - except stripe.StripeError: - pass + tier_costs: dict[str, int] = { + SubscriptionTier.FREE.value: 0, + SubscriptionTier.ENTERPRISE.value: 0, + } + + async def _cost(pid: str | None) -> int: + return (await _get_stripe_price_amount(pid) or 0) if pid else 0 + + costs = await asyncio.gather(*[_cost(pid) for pid in price_ids]) + for t, cost in zip(paid_tiers, costs): tier_costs[t.value] = cost + current_monthly_cost = tier_costs.get(tier.value, 0) + proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost) + return SubscriptionStatusResponse( tier=tier.value, - monthly_cost=tier_costs.get(tier.value, 0), + monthly_cost=current_monthly_cost, tier_costs=tier_costs, + proration_credit_cents=proration_credit, ) @@ -766,24 +838,125 @@ async def update_subscription_tier( Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False ) - # Downgrade to FREE: cancel active Stripe subscription, then update the DB tier. + # Downgrade to FREE: schedule Stripe cancellation at period end so the user + # keeps their tier for the time they already paid for. The DB tier is NOT + # updated here when a subscription exists — the customer.subscription.deleted + # webhook fires at period end and downgrades to FREE then. + # Exception: if the user has no active Stripe subscription (e.g. admin-granted + # tier), cancel_stripe_subscription returns False and we update the DB tier + # immediately since no webhook will ever fire. + # When payment is disabled entirely, update the DB tier directly. 
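# --- Illustrative aside, not part of the patch: the downgrade decision
# described in the comment above, condensed into a runnable sketch.
# `payment_enabled` and `had_subscription` (cancel_stripe_subscription's
# return value) are the only inputs; the function name is an assumption
# used for illustration only.
def _free_downgrade_outcome(payment_enabled: bool, had_subscription: bool) -> str:
    if payment_enabled and had_subscription:
        # Stripe cancels at period end; the customer.subscription.deleted
        # webhook performs the DB downgrade later.
        return "cancel scheduled; DB tier unchanged until the webhook fires"
    # Payments disabled, or an admin-granted tier with no Stripe
    # subscription: no webhook will ever fire, so update the DB now.
    return "set_subscription_tier(user_id, FREE) immediately"

assert "webhook" in _free_downgrade_outcome(True, True)
assert "immediately" in _free_downgrade_outcome(True, False)
assert "immediately" in _free_downgrade_outcome(False, False)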
if tier == SubscriptionTier.FREE: if payment_enabled: - await cancel_stripe_subscription(user_id) + try: + had_subscription = await cancel_stripe_subscription(user_id) + except stripe.StripeError as e: + # Log full Stripe error server-side but return a generic message + # to the client — raw Stripe errors can leak customer/sub IDs and + # infrastructure config details. + logger.exception( + "Stripe error cancelling subscription for user %s: %s", + user_id, + e, + ) + raise HTTPException( + status_code=502, + detail=( + "Unable to cancel your subscription right now. " + "Please try again or contact support." + ), + ) + if not had_subscription: + # No active Stripe subscription found — the user was on an + # admin-granted tier. Update DB immediately since the + # subscription.deleted webhook will never fire. + await set_subscription_tier(user_id, tier) + return SubscriptionCheckoutResponse(url="") await set_subscription_tier(user_id, tier) return SubscriptionCheckoutResponse(url="") - # Beta users (payment not enabled) → update tier directly without Stripe. + # Paid tier changes require payment to be enabled — block self-service upgrades + # when the flag is off. Admins use the /api/admin/ routes to set tiers directly. if not payment_enabled: - await set_subscription_tier(user_id, tier) + raise HTTPException( + status_code=422, + detail=f"Subscription not available for tier {tier}", + ) + + # No-op short-circuit: if the user is already on the requested paid tier, + # do NOT create a new Checkout Session. Without this guard, a duplicate + # request (double-click, retried POST, stale page) creates a second + # subscription for the same price; the user would be charged for both + # until `_cleanup_stale_subscriptions` runs from the resulting webhook — + # which only fires after the second charge has cleared. + if (user.subscription_tier or SubscriptionTier.FREE) == tier: return SubscriptionCheckoutResponse(url="") - # Paid upgrade → create Stripe Checkout Session. + # Paid→paid tier change: if the user already has a Stripe subscription, + # modify it in-place with proration instead of creating a new Checkout + # Session. This preserves remaining paid time and avoids double-charging. + # The customer.subscription.updated webhook fires and updates the DB tier. + current_tier = user.subscription_tier or SubscriptionTier.FREE + if current_tier in (SubscriptionTier.PRO, SubscriptionTier.BUSINESS): + try: + modified = await modify_stripe_subscription_for_tier(user_id, tier) + if modified: + return SubscriptionCheckoutResponse(url="") + # modify_stripe_subscription_for_tier returns False when no active + # Stripe subscription exists — i.e. the user has an admin-granted + # paid tier with no Stripe record. In that case, update the DB + # tier directly (same as the FREE-downgrade path for admin-granted + # users) rather than sending them through a new Checkout Session. + await set_subscription_tier(user_id, tier) + return SubscriptionCheckoutResponse(url="") + except ValueError as e: + raise HTTPException(status_code=422, detail=str(e)) + except stripe.StripeError as e: + logger.exception( + "Stripe error modifying subscription for user %s: %s", user_id, e + ) + raise HTTPException( + status_code=502, + detail=( + "Unable to update your subscription right now. " + "Please try again or contact support." + ), + ) + + # Paid upgrade from FREE → create Stripe Checkout Session. 
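# --- Illustrative aside, not part of the patch: the accept/reject behaviour
# of _validate_checkout_redirect_url (defined above), re-sketched as a
# standalone function so it can be run directly. The allowed origin is an
# assumed example value, not the project's actual configuration.
from urllib.parse import urlparse

_ALLOWED = "https://app.example.com"  # stand-in for frontend_base_url

def _same_origin(url: str) -> bool:
    # Pre-parse rejects, http(s)-only, no userinfo, exact scheme+netloc match.
    if "\\" in url or any(ord(c) < 0x20 for c in url):
        return False
    p, a = urlparse(url), urlparse(_ALLOWED)
    if p.scheme not in ("http", "https") or "@" in p.netloc:
        return False
    return (p.scheme, p.netloc) == (a.scheme, a.netloc)

assert _same_origin("https://app.example.com/checkout/success")
assert not _same_origin("https://evil.example.org/phish")              # wrong host
assert not _same_origin("http://app.example.com/success")              # scheme downgrade
assert not _same_origin("https://app.example.com@evil.example.org/")   # userinfo trick
assert not _same_origin("https:\\\\app.example.com\\success")          # backslash smuggling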
if not request.success_url or not request.cancel_url: raise HTTPException( status_code=422, detail="success_url and cancel_url are required for paid tier upgrades", ) + # Open-redirect protection: both URLs must point to the configured frontend + # origin, otherwise an attacker could use our Stripe integration as a + # redirector to arbitrary phishing sites. + # + # Fail early with a clear 503 if the server is misconfigured (neither + # frontend_base_url nor platform_base_url set), so operators get an + # actionable error instead of the misleading "must match the platform + # frontend origin" 422 that _validate_checkout_redirect_url would otherwise + # produce when `allowed` is empty. + if not (settings.config.frontend_base_url or settings.config.platform_base_url): + logger.error( + "update_subscription_tier: neither frontend_base_url nor " + "platform_base_url is configured; cannot validate checkout redirect URLs" + ) + raise HTTPException( + status_code=503, + detail=( + "Payment redirect URLs cannot be validated: " + "frontend_base_url or platform_base_url must be set on the server." + ), + ) + if not _validate_checkout_redirect_url( + request.success_url + ) or not _validate_checkout_redirect_url(request.cancel_url): + raise HTTPException( + status_code=422, + detail="success_url and cancel_url must match the platform frontend origin", + ) try: url = await create_subscription_checkout( user_id=user_id, @@ -791,8 +964,19 @@ async def update_subscription_tier( success_url=request.success_url, cancel_url=request.cancel_url, ) - except (ValueError, stripe.StripeError) as e: + except ValueError as e: raise HTTPException(status_code=422, detail=str(e)) + except stripe.StripeError as e: + logger.exception( + "Stripe error creating checkout session for user %s: %s", user_id, e + ) + raise HTTPException( + status_code=502, + detail=( + "Unable to start checkout right now. " + "Please try again or contact support." + ), + ) return SubscriptionCheckoutResponse(url=url) @@ -801,44 +985,78 @@ async def update_subscription_tier( path="/credits/stripe_webhook", summary="Handle Stripe webhooks", tags=["credits"] ) async def stripe_webhook(request: Request): + webhook_secret = settings.secrets.stripe_webhook_secret + if not webhook_secret: + # Guard: an empty secret allows HMAC forgery (attacker can compute a valid + # signature over the same empty key). Reject all webhook calls when unconfigured. + logger.error( + "stripe_webhook: STRIPE_WEBHOOK_SECRET is not configured — " + "rejecting request to prevent signature bypass" + ) + raise HTTPException(status_code=503, detail="Webhook not configured") + # Get the raw request body payload = await request.body() # Get the signature header sig_header = request.headers.get("stripe-signature") try: - event = stripe.Webhook.construct_event( - payload, sig_header, settings.secrets.stripe_webhook_secret - ) - except ValueError as e: + event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret) + except ValueError: # Invalid payload - raise HTTPException( - status_code=400, detail=f"Invalid payload: {str(e) or type(e).__name__}" - ) - except stripe.SignatureVerificationError as e: + raise HTTPException(status_code=400, detail="Invalid payload") + except stripe.SignatureVerificationError: # Invalid signature - raise HTTPException( - status_code=400, detail=f"Invalid signature: {str(e) or type(e).__name__}" + raise HTTPException(status_code=400, detail="Invalid signature") + + # Defensive payload extraction. 
A malformed payload (missing/non-dict + # `data.object`, missing `id`) would otherwise raise KeyError/TypeError + # AFTER signature verification — which Stripe interprets as a delivery + # failure and retries forever, while spamming Sentry with no useful info. + # Acknowledge with 200 and a warning so Stripe stops retrying. + event_type = event.get("type", "") + event_data = event.get("data") or {} + data_object = event_data.get("object") if isinstance(event_data, dict) else None + if not isinstance(data_object, dict): + logger.warning( + "stripe_webhook: %s missing or non-dict data.object; ignoring", + event_type, ) + return Response(status_code=200) - if ( - event["type"] == "checkout.session.completed" - or event["type"] == "checkout.session.async_payment_succeeded" + if event_type in ( + "checkout.session.completed", + "checkout.session.async_payment_succeeded", ): - await UserCredit().fulfill_checkout(session_id=event["data"]["object"]["id"]) + session_id = data_object.get("id") + if not session_id: + logger.warning( + "stripe_webhook: %s missing data.object.id; ignoring", event_type + ) + return Response(status_code=200) + await UserCredit().fulfill_checkout(session_id=session_id) - if event["type"] in ( + if event_type in ( "customer.subscription.created", "customer.subscription.updated", "customer.subscription.deleted", ): - await sync_subscription_from_stripe(event["data"]["object"]) + await sync_subscription_from_stripe(data_object) - if event["type"] == "charge.dispute.created": - await UserCredit().handle_dispute(event["data"]["object"]) + if event_type == "invoice.payment_failed": + await handle_subscription_payment_failure(data_object) - if event["type"] == "refund.created" or event["type"] == "charge.dispute.closed": - await UserCredit().deduct_credits(event["data"]["object"]) + # `handle_dispute` and `deduct_credits` expect Stripe SDK typed objects + # (Dispute/Refund). The Stripe webhook payload's `data.object` is a + # StripeObject (a dict subclass) carrying that runtime shape, so we cast + # to satisfy the type checker without changing runtime behaviour. 
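# --- Illustrative aside, not part of the patch: typing.cast() is a
# static-checker-only annotation. At runtime it returns its argument
# unchanged, so the dispatches below hand the handlers the exact same
# StripeObject dict either way. A minimal demonstration:
from typing import cast

_payload = {"id": "dp_123", "amount": 1999}  # stand-in for a StripeObject
assert cast("dict[str, int]", _payload) is _payload  # no copy, no conversion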
+ if event_type == "charge.dispute.created": + await UserCredit().handle_dispute(cast(stripe.Dispute, data_object)) + + if event_type == "refund.created" or event_type == "charge.dispute.closed": + await UserCredit().deduct_credits( + cast("stripe.Refund | stripe.Dispute", data_object) + ) return Response(status_code=200) diff --git a/autogpt_platform/backend/backend/blocks/llm.py b/autogpt_platform/backend/backend/blocks/llm.py index 7becac185d..8543a03b69 100644 --- a/autogpt_platform/backend/backend/blocks/llm.py +++ b/autogpt_platform/backend/backend/blocks/llm.py @@ -106,7 +106,6 @@ class LlmModelMeta(EnumMeta): class LlmModel(str, Enum, metaclass=LlmModelMeta): - @classmethod def _missing_(cls, value: object) -> "LlmModel | None": """Handle provider-prefixed model names like 'anthropic/claude-sonnet-4-6'.""" @@ -203,6 +202,8 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta): GROK_4 = "x-ai/grok-4" GROK_4_FAST = "x-ai/grok-4-fast" GROK_4_1_FAST = "x-ai/grok-4.1-fast" + GROK_4_20 = "x-ai/grok-4.20" + GROK_4_20_MULTI_AGENT = "x-ai/grok-4.20-multi-agent" GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1" KIMI_K2 = "moonshotai/kimi-k2" QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507" @@ -627,6 +628,18 @@ MODEL_METADATA = { LlmModel.GROK_4_1_FAST: ModelMetadata( "open_router", 2000000, 30000, "Grok 4.1 Fast", "OpenRouter", "xAI", 1 ), + LlmModel.GROK_4_20: ModelMetadata( + "open_router", 2000000, 100000, "Grok 4.20", "OpenRouter", "xAI", 3 + ), + LlmModel.GROK_4_20_MULTI_AGENT: ModelMetadata( + "open_router", + 2000000, + 100000, + "Grok 4.20 Multi-Agent", + "OpenRouter", + "xAI", + 3, + ), LlmModel.GROK_CODE_FAST_1: ModelMetadata( "open_router", 256000, 10000, "Grok Code Fast 1", "OpenRouter", "xAI", 1 ), @@ -987,7 +1000,6 @@ async def llm_call( reasoning=reasoning, ) elif provider == "anthropic": - an_tools = convert_openai_tool_fmt_to_anthropic(tools) # Cache tool definitions alongside the system prompt. # Placing cache_control on the last tool caches all tool schemas as a diff --git a/autogpt_platform/backend/backend/copilot/baseline/service.py b/autogpt_platform/backend/backend/copilot/baseline/service.py index dd6aa121b6..a2813ad881 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service.py @@ -67,11 +67,15 @@ from backend.copilot.transcript import ( STOP_REASON_END_TURN, STOP_REASON_TOOL_USE, TranscriptDownload, + detect_gap, download_transcript, + extract_context_messages, + strip_for_upload, upload_transcript, validate_transcript, ) from backend.copilot.transcript_builder import TranscriptBuilder +from backend.util import json as util_json from backend.util.exceptions import NotFoundError from backend.util.prompt import ( compress_context, @@ -699,81 +703,147 @@ async def _compress_session_messages( return messages -def is_transcript_stale(dl: TranscriptDownload | None, session_msg_count: int) -> bool: - """Return ``True`` when a download doesn't cover the current session. - - A transcript is stale when it has a known ``message_count`` and that - count doesn't reach ``session_msg_count - 1`` (i.e. the session has - already advanced beyond what the stored transcript captures). - Loading a stale transcript would silently drop intermediate turns, - so callers should treat stale as "skip load, skip upload". - - An unknown ``message_count`` (``0``) is treated as **not stale** - because older transcripts uploaded before msg_count tracking - existed must still be usable. 
- """ - if dl is None: - return False - if not dl.message_count: - return False - return dl.message_count < session_msg_count - 1 - - -def should_upload_transcript( - user_id: str | None, transcript_covers_prefix: bool -) -> bool: +def should_upload_transcript(user_id: str | None, upload_safe: bool) -> bool: """Return ``True`` when the caller should upload the final transcript. - Uploads require a logged-in user (for the storage key) *and* a - transcript that covered the session prefix when loaded — otherwise - we'd be overwriting a more complete version in storage with a - partial one built from just the current turn. + Uploads require a logged-in user (for the storage key) *and* a safe + upload signal from ``_load_prior_transcript`` — i.e. GCS does not hold a + newer version that we'd be overwriting. """ - return bool(user_id) and transcript_covers_prefix + return bool(user_id) and upload_safe + + +def _append_gap_to_builder( + gap: list[ChatMessage], + builder: TranscriptBuilder, +) -> None: + """Append gap messages from chat-db into the TranscriptBuilder. + + Converts ChatMessage (OpenAI format) to TranscriptBuilder entries + (Claude CLI JSONL format) so the uploaded transcript covers all turns. + + Pre-condition: ``gap`` always starts at a user or assistant boundary + (never mid-turn at a ``tool`` role), because ``detect_gap`` enforces + ``session_messages[wm-1].role == 'assistant'`` before returning a non-empty + gap. Any ``tool`` role messages within the gap always follow an assistant + entry that already exists in the builder or in the gap itself. + """ + for msg in gap: + if msg.role == "user": + builder.append_user(msg.content or "") + elif msg.role == "assistant": + content_blocks: list[dict] = [] + if msg.content: + content_blocks.append({"type": "text", "text": msg.content}) + if msg.tool_calls: + for tc in msg.tool_calls: + fn = tc.get("function", {}) if isinstance(tc, dict) else {} + input_data = util_json.loads(fn.get("arguments", "{}"), fallback={}) + content_blocks.append( + { + "type": "tool_use", + "id": tc.get("id", "") if isinstance(tc, dict) else "", + "name": fn.get("name", "unknown"), + "input": input_data, + } + ) + if not content_blocks: + # Fallback: ensure every assistant gap message produces an entry + # so the builder's entry count matches the gap length. + content_blocks.append({"type": "text", "text": ""}) + builder.append_assistant(content_blocks=content_blocks) + elif msg.role == "tool": + if msg.tool_call_id: + builder.append_tool_result( + tool_use_id=msg.tool_call_id, + content=msg.content or "", + ) + else: + # Malformed tool message — no tool_call_id to link to an + # assistant tool_use block. Skip to avoid an unmatched + # tool_result entry in the builder (which would confuse --resume). + logger.warning( + "[Baseline] Skipping tool gap message with no tool_call_id" + ) async def _load_prior_transcript( user_id: str, session_id: str, - session_msg_count: int, + session_messages: list[ChatMessage], transcript_builder: TranscriptBuilder, -) -> bool: - """Download and load the prior transcript into ``transcript_builder``. +) -> tuple[bool, "TranscriptDownload | None"]: + """Download and load the prior CLI session into ``transcript_builder``. - Returns ``True`` when the loaded transcript fully covers the session - prefix; ``False`` otherwise (stale, missing, invalid, or download - error). Callers should suppress uploads when this returns ``False`` - to avoid overwriting a more complete version in storage. 
+ Returns a tuple of (upload_safe, transcript_download): + - ``upload_safe`` is ``True`` when it is safe to upload at the end of this + turn. Upload is suppressed only for **download errors** (unknown GCS + state) — missing and invalid files return ``True`` because there is + nothing in GCS worth protecting against overwriting. + - ``transcript_download`` is a ``TranscriptDownload`` with str content + (pre-decoded and stripped) when available, or ``None`` when no valid + transcript could be loaded. Callers pass this to + ``extract_context_messages`` to build the LLM context. """ try: - dl = await download_transcript(user_id, session_id, log_prefix="[Baseline]") - except Exception as e: - logger.warning("[Baseline] Transcript download failed: %s", e) - return False - - if dl is None: - logger.debug("[Baseline] No transcript available") - return False - - if not validate_transcript(dl.content): - logger.warning("[Baseline] Downloaded transcript but invalid") - return False - - if is_transcript_stale(dl, session_msg_count): - logger.warning( - "[Baseline] Transcript stale: covers %d of %d messages, skipping", - dl.message_count, - session_msg_count, + restore = await download_transcript( + user_id, session_id, log_prefix="[Baseline]" ) - return False + except Exception as e: + logger.warning("[Baseline] Session restore failed: %s", e) + # Unknown GCS state — be conservative, skip upload. + return False, None - transcript_builder.load_previous(dl.content, log_prefix="[Baseline]") + if restore is None: + logger.debug("[Baseline] No CLI session available — will upload fresh") + # Nothing in GCS to protect; allow upload so the first baseline turn + # writes the initial transcript snapshot. + return True, None + + content_bytes = restore.content + try: + raw_str = ( + content_bytes.decode("utf-8") + if isinstance(content_bytes, bytes) + else content_bytes + ) + except UnicodeDecodeError: + logger.warning("[Baseline] CLI session content is not valid UTF-8") + # Corrupt file in GCS; overwriting with a valid one is better. + return True, None + + stripped = strip_for_upload(raw_str) + if not validate_transcript(stripped): + logger.warning("[Baseline] CLI session content invalid after strip") + # Corrupt file in GCS; overwriting with a valid one is better. + return True, None + + transcript_builder.load_previous(stripped, log_prefix="[Baseline]") logger.info( - "[Baseline] Loaded transcript: %dB, msg_count=%d", - len(dl.content), - dl.message_count, + "[Baseline] Loaded CLI session: %dB, msg_count=%d", + len(content_bytes) if isinstance(content_bytes, bytes) else len(raw_str), + restore.message_count, ) - return True + + gap = detect_gap(restore, session_messages) + if gap: + _append_gap_to_builder(gap, transcript_builder) + logger.info( + "[Baseline] Filled gap: loaded %d transcript msgs + %d gap msgs from DB", + restore.message_count, + len(gap), + ) + + # Return a str-content version so extract_context_messages receives a + # pre-decoded, stripped transcript (avoids redundant decode + strip). + # TranscriptDownload.content is typed as bytes | str; we pass str here + # to avoid a redundant encode + decode round-trip. 
+ str_restore = TranscriptDownload( + content=stripped, + message_count=restore.message_count, + mode=restore.mode, + ) + return True, str_restore async def _upload_final_transcript( @@ -807,10 +877,10 @@ async def _upload_final_transcript( upload_transcript( user_id=user_id, session_id=session_id, - content=content, + content=content.encode("utf-8"), message_count=session_msg_count, + mode="baseline", log_prefix="[Baseline]", - skip_strip=True, ) ) _background_tasks.add(upload_task) @@ -897,7 +967,7 @@ async def stream_chat_completion_baseline( # --- Transcript support (feature parity with SDK path) --- transcript_builder = TranscriptBuilder() - transcript_covers_prefix = True + transcript_upload_safe = True # Build system prompt only on the first turn to avoid mid-conversation # changes from concurrent chats updating business understanding. @@ -914,15 +984,16 @@ async def stream_chat_completion_baseline( # Run download + prompt build concurrently — both are independent I/O # on the request critical path. + transcript_download: TranscriptDownload | None = None if user_id and len(session.messages) > 1: ( - transcript_covers_prefix, + (transcript_upload_safe, transcript_download), (base_system_prompt, understanding), ) = await asyncio.gather( _load_prior_transcript( user_id=user_id, session_id=session_id, - session_msg_count=len(session.messages), + session_messages=session.messages, transcript_builder=transcript_builder, ), prompt_task, @@ -962,9 +1033,14 @@ async def stream_chat_completion_baseline( warm_ctx = await fetch_warm_context(user_id, message or "") - # Compress context if approaching the model's token limit + # Context path: transcript content (compacted, isCompactSummary preserved) + + # gap (DB messages after watermark) + current user turn. + # This avoids re-reading the full session history from DB on every turn. + # See extract_context_messages() in transcript.py for the shared primitive. + prior_context = extract_context_messages(transcript_download, session.messages) messages_for_context = await _compress_session_messages( - session.messages, model=active_model + prior_context + ([session.messages[-1]] if session.messages else []), + model=active_model, ) # Build OpenAI message list from session history. @@ -1308,7 +1384,7 @@ async def stream_chat_completion_baseline( stop_reason=STOP_REASON_END_TURN, ) - if user_id and should_upload_transcript(user_id, transcript_covers_prefix): + if user_id and should_upload_transcript(user_id, transcript_upload_safe): await _upload_final_transcript( user_id=user_id, session_id=session_id, diff --git a/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py b/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py index baeb3e3648..4247c76c19 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py +++ b/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py @@ -1,7 +1,7 @@ """Integration tests for baseline transcript flow. -Exercises the real helpers in ``baseline/service.py`` that download, -validate, load, append to, backfill, and upload the transcript. +Exercises the real helpers in ``baseline/service.py`` that restore, +validate, load, append to, backfill, and upload the CLI session. Storage is mocked via ``download_transcript`` / ``upload_transcript`` patches; no network access is required. 
""" @@ -12,13 +12,14 @@ from unittest.mock import AsyncMock, patch import pytest from backend.copilot.baseline.service import ( + _append_gap_to_builder, _load_prior_transcript, _record_turn_to_transcript, _resolve_baseline_model, _upload_final_transcript, - is_transcript_stale, should_upload_transcript, ) +from backend.copilot.model import ChatMessage from backend.copilot.service import config from backend.copilot.transcript import ( STOP_REASON_END_TURN, @@ -54,6 +55,13 @@ def _make_transcript_content(*roles: str) -> str: return "\n".join(lines) + "\n" +def _make_session_messages(*roles: str) -> list[ChatMessage]: + """Build a list of ChatMessage objects matching the given roles.""" + return [ + ChatMessage(role=r, content=f"{r} message {i}") for i, r in enumerate(roles) + ] + + class TestResolveBaselineModel: """Model selection honours the per-request mode.""" @@ -73,87 +81,102 @@ class TestResolveBaselineModel: class TestLoadPriorTranscript: - """``_load_prior_transcript`` wraps the download + validate + load flow.""" + """``_load_prior_transcript`` wraps the CLI session restore + validate + load flow.""" @pytest.mark.asyncio async def test_loads_fresh_transcript(self): builder = TranscriptBuilder() content = _make_transcript_content("user", "assistant") - download = TranscriptDownload(content=content, message_count=2) + restore = TranscriptDownload( + content=content.encode("utf-8"), message_count=2, mode="sdk" + ) with patch( "backend.copilot.baseline.service.download_transcript", - new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ): - covers = await _load_prior_transcript( + covers, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=3, + session_messages=_make_session_messages("user", "assistant", "user"), transcript_builder=builder, ) assert covers is True + assert dl is not None + assert dl.message_count == 2 assert builder.entry_count == 2 assert builder.last_entry_type == "assistant" @pytest.mark.asyncio - async def test_rejects_stale_transcript(self): - """msg_count strictly less than session-1 is treated as stale.""" + async def test_fills_gap_when_transcript_is_behind(self): + """When transcript covers fewer messages than session, gap is filled from DB.""" builder = TranscriptBuilder() content = _make_transcript_content("user", "assistant") - # session has 6 messages, transcript only covers 2 → stale. 
- download = TranscriptDownload(content=content, message_count=2) + # transcript covers 2 messages, session has 4 (plus current user turn = 5) + restore = TranscriptDownload( + content=content.encode("utf-8"), message_count=2, mode="baseline" + ) with patch( "backend.copilot.baseline.service.download_transcript", - new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ): - covers = await _load_prior_transcript( + covers, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=6, + session_messages=_make_session_messages( + "user", "assistant", "user", "assistant", "user" + ), transcript_builder=builder, ) - assert covers is False - assert builder.is_empty + assert covers is True + assert dl is not None + # 2 from transcript + 2 gap messages (user+assistant at positions 2,3) + assert builder.entry_count == 4 @pytest.mark.asyncio - async def test_missing_transcript_returns_false(self): + async def test_missing_transcript_allows_upload(self): + """Nothing in GCS → upload is safe; the turn writes the first snapshot.""" builder = TranscriptBuilder() with patch( "backend.copilot.baseline.service.download_transcript", new=AsyncMock(return_value=None), ): - covers = await _load_prior_transcript( + upload_safe, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=2, + session_messages=_make_session_messages("user", "assistant"), transcript_builder=builder, ) - assert covers is False + assert upload_safe is True + assert dl is None assert builder.is_empty @pytest.mark.asyncio - async def test_invalid_transcript_returns_false(self): + async def test_invalid_transcript_allows_upload(self): + """Corrupt file in GCS → overwriting with a valid one is better.""" builder = TranscriptBuilder() - download = TranscriptDownload( - content='{"type":"progress","uuid":"a"}\n', + restore = TranscriptDownload( + content=b'{"type":"progress","uuid":"a"}\n', message_count=1, + mode="sdk", ) with patch( "backend.copilot.baseline.service.download_transcript", - new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ): - covers = await _load_prior_transcript( + upload_safe, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=2, + session_messages=_make_session_messages("user", "assistant"), transcript_builder=builder, ) - assert covers is False + assert upload_safe is True + assert dl is None assert builder.is_empty @pytest.mark.asyncio @@ -163,36 +186,39 @@ class TestLoadPriorTranscript: "backend.copilot.baseline.service.download_transcript", new=AsyncMock(side_effect=RuntimeError("boom")), ): - covers = await _load_prior_transcript( + covers, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=2, + session_messages=_make_session_messages("user", "assistant"), transcript_builder=builder, ) assert covers is False + assert dl is None assert builder.is_empty @pytest.mark.asyncio async def test_zero_message_count_not_stale(self): - """When msg_count is 0 (unknown), staleness check is skipped.""" + """When msg_count is 0 (unknown), gap detection is skipped.""" builder = TranscriptBuilder() - download = TranscriptDownload( - content=_make_transcript_content("user", "assistant"), + restore = TranscriptDownload( + content=_make_transcript_content("user", "assistant").encode("utf-8"), message_count=0, + mode="sdk", ) with patch( "backend.copilot.baseline.service.download_transcript", - 
new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ): - covers = await _load_prior_transcript( + covers, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=20, + session_messages=_make_session_messages(*["user"] * 20), transcript_builder=builder, ) assert covers is True + assert dl is not None assert builder.entry_count == 2 @@ -227,7 +253,7 @@ class TestUploadFinalTranscript: assert call_kwargs["user_id"] == "user-1" assert call_kwargs["session_id"] == "session-1" assert call_kwargs["message_count"] == 2 - assert "hello" in call_kwargs["content"] + assert b"hello" in call_kwargs["content"] @pytest.mark.asyncio async def test_skips_upload_when_builder_empty(self): @@ -374,17 +400,19 @@ class TestRoundTrip: @pytest.mark.asyncio async def test_full_round_trip(self): prior = _make_transcript_content("user", "assistant") - download = TranscriptDownload(content=prior, message_count=2) + restore = TranscriptDownload( + content=prior.encode("utf-8"), message_count=2, mode="sdk" + ) builder = TranscriptBuilder() with patch( "backend.copilot.baseline.service.download_transcript", - new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ): - covers = await _load_prior_transcript( + covers, _ = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=3, + session_messages=_make_session_messages("user", "assistant", "user"), transcript_builder=builder, ) assert covers is True @@ -424,11 +452,11 @@ class TestRoundTrip: upload_mock.assert_awaited_once() assert upload_mock.await_args is not None uploaded = upload_mock.await_args.kwargs["content"] - assert "new question" in uploaded - assert "new answer" in uploaded + assert b"new question" in uploaded + assert b"new answer" in uploaded # Original content preserved in the round trip. - assert "user message 0" in uploaded - assert "assistant message 1" in uploaded + assert b"user message 0" in uploaded + assert b"assistant message 1" in uploaded @pytest.mark.asyncio async def test_backfill_append_guard(self): @@ -459,36 +487,6 @@ class TestRoundTrip: assert builder.entry_count == initial_count -class TestIsTranscriptStale: - """``is_transcript_stale`` gates prior-transcript loading.""" - - def test_none_download_is_not_stale(self): - assert is_transcript_stale(None, session_msg_count=5) is False - - def test_zero_message_count_is_not_stale(self): - """Legacy transcripts without msg_count tracking must remain usable.""" - dl = TranscriptDownload(content="", message_count=0) - assert is_transcript_stale(dl, session_msg_count=20) is False - - def test_stale_when_covers_less_than_prefix(self): - dl = TranscriptDownload(content="", message_count=2) - # session has 6 messages; transcript must cover at least 5 (6-1). 
- assert is_transcript_stale(dl, session_msg_count=6) is True - - def test_fresh_when_covers_full_prefix(self): - dl = TranscriptDownload(content="", message_count=5) - assert is_transcript_stale(dl, session_msg_count=6) is False - - def test_fresh_when_exceeds_prefix(self): - """Race: transcript ahead of session count is still acceptable.""" - dl = TranscriptDownload(content="", message_count=10) - assert is_transcript_stale(dl, session_msg_count=6) is False - - def test_boundary_equal_to_prefix_minus_one(self): - dl = TranscriptDownload(content="", message_count=5) - assert is_transcript_stale(dl, session_msg_count=6) is False - - class TestShouldUploadTranscript: """``should_upload_transcript`` gates the final upload.""" @@ -510,7 +508,7 @@ class TestShouldUploadTranscript: class TestTranscriptLifecycle: - """End-to-end: download → validate → build → upload. + """End-to-end: restore → validate → build → upload. Simulates the full transcript lifecycle inside ``stream_chat_completion_baseline`` by mocking the storage layer and @@ -519,27 +517,29 @@ class TestTranscriptLifecycle: @pytest.mark.asyncio async def test_full_lifecycle_happy_path(self): - """Fresh download, append a turn, upload covers the session.""" + """Fresh restore, append a turn, upload covers the session.""" builder = TranscriptBuilder() prior = _make_transcript_content("user", "assistant") - download = TranscriptDownload(content=prior, message_count=2) + restore = TranscriptDownload( + content=prior.encode("utf-8"), message_count=2, mode="sdk" + ) upload_mock = AsyncMock(return_value=None) with ( patch( "backend.copilot.baseline.service.download_transcript", - new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ), patch( "backend.copilot.baseline.service.upload_transcript", new=upload_mock, ), ): - # --- 1. Download & load prior transcript --- - covers = await _load_prior_transcript( + # --- 1. Restore & load prior session --- + covers, _ = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=3, + session_messages=_make_session_messages("user", "assistant", "user"), transcript_builder=builder, ) assert covers is True @@ -559,10 +559,7 @@ class TestTranscriptLifecycle: # --- 3. Gate + upload --- assert ( - should_upload_transcript( - user_id="user-1", transcript_covers_prefix=covers - ) - is True + should_upload_transcript(user_id="user-1", upload_safe=covers) is True ) await _upload_final_transcript( user_id="user-1", @@ -574,20 +571,21 @@ class TestTranscriptLifecycle: upload_mock.assert_awaited_once() assert upload_mock.await_args is not None uploaded = upload_mock.await_args.kwargs["content"] - assert "follow-up question" in uploaded - assert "follow-up answer" in uploaded + assert b"follow-up question" in uploaded + assert b"follow-up answer" in uploaded # Original prior-turn content preserved. - assert "user message 0" in uploaded - assert "assistant message 1" in uploaded + assert b"user message 0" in uploaded + assert b"assistant message 1" in uploaded @pytest.mark.asyncio - async def test_lifecycle_stale_download_suppresses_upload(self): - """Stale download → covers=False → upload must be skipped.""" + async def test_lifecycle_stale_download_fills_gap(self): + """When transcript covers fewer messages, gap is filled rather than rejected.""" builder = TranscriptBuilder() - # session has 10 msgs but stored transcript only covers 2 → stale. + # session has 5 msgs but stored transcript only covers 2 → gap filled. 
stale = TranscriptDownload( - content=_make_transcript_content("user", "assistant"), + content=_make_transcript_content("user", "assistant").encode("utf-8"), message_count=2, + mode="baseline", ) upload_mock = AsyncMock(return_value=None) @@ -601,20 +599,18 @@ class TestTranscriptLifecycle: new=upload_mock, ), ): - covers = await _load_prior_transcript( + covers, _ = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=10, + session_messages=_make_session_messages( + "user", "assistant", "user", "assistant", "user" + ), transcript_builder=builder, ) - assert covers is False - # The caller's gate mirrors the production path. - assert ( - should_upload_transcript(user_id="user-1", transcript_covers_prefix=covers) - is False - ) - upload_mock.assert_not_awaited() + assert covers is True + # Gap was filled: 2 from transcript + 2 gap messages + assert builder.entry_count == 4 @pytest.mark.asyncio async def test_lifecycle_anonymous_user_skips_upload(self): @@ -627,15 +623,11 @@ class TestTranscriptLifecycle: stop_reason=STOP_REASON_END_TURN, ) - assert ( - should_upload_transcript(user_id=None, transcript_covers_prefix=True) - is False - ) + assert should_upload_transcript(user_id=None, upload_safe=True) is False @pytest.mark.asyncio async def test_lifecycle_missing_download_still_uploads_new_content(self): - """No prior transcript → covers defaults to True in the service, - new turn should upload cleanly.""" + """No prior session → upload is safe; the turn writes the first snapshot.""" builder = TranscriptBuilder() upload_mock = AsyncMock(return_value=None) with ( @@ -648,20 +640,117 @@ class TestTranscriptLifecycle: new=upload_mock, ), ): - covers = await _load_prior_transcript( + upload_safe, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=1, + session_messages=_make_session_messages("user"), transcript_builder=builder, ) - # No download: covers is False, so the production path would - # skip upload. This protects against overwriting a future - # more-complete transcript with a single-turn snapshot. - assert covers is False + # Nothing in GCS → upload is safe so the first baseline turn + # can write the initial transcript snapshot. 
+ assert upload_safe is True + assert dl is None assert ( - should_upload_transcript( - user_id="user-1", transcript_covers_prefix=covers - ) - is False + should_upload_transcript(user_id="user-1", upload_safe=upload_safe) + is True ) - upload_mock.assert_not_awaited() + + +# --------------------------------------------------------------------------- +# _append_gap_to_builder +# --------------------------------------------------------------------------- + + +class TestAppendGapToBuilder: + """``_append_gap_to_builder`` converts ChatMessage objects to TranscriptBuilder entries.""" + + def test_user_message_appended(self): + builder = TranscriptBuilder() + msgs = [ChatMessage(role="user", content="hello")] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 1 + assert builder.last_entry_type == "user" + + def test_assistant_text_message_appended(self): + builder = TranscriptBuilder() + msgs = [ + ChatMessage(role="user", content="q"), + ChatMessage(role="assistant", content="answer"), + ] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 2 + assert builder.last_entry_type == "assistant" + assert "answer" in builder.to_jsonl() + + def test_assistant_with_tool_calls_appended(self): + """Assistant tool_calls are recorded as tool_use blocks in the transcript.""" + builder = TranscriptBuilder() + tool_call = { + "id": "tc-1", + "type": "function", + "function": {"name": "my_tool", "arguments": '{"key":"val"}'}, + } + msgs = [ChatMessage(role="assistant", content=None, tool_calls=[tool_call])] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 1 + jsonl = builder.to_jsonl() + assert "tool_use" in jsonl + assert "my_tool" in jsonl + assert "tc-1" in jsonl + + def test_assistant_invalid_json_args_uses_empty_dict(self): + """Malformed JSON in tool_call arguments falls back to {}.""" + builder = TranscriptBuilder() + tool_call = { + "id": "tc-bad", + "type": "function", + "function": {"name": "bad_tool", "arguments": "not-json"}, + } + msgs = [ChatMessage(role="assistant", content=None, tool_calls=[tool_call])] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 1 + jsonl = builder.to_jsonl() + assert '"input":{}' in jsonl + + def test_assistant_empty_content_and_no_tools_uses_fallback(self): + """Assistant with no content and no tool_calls gets a fallback empty text block.""" + builder = TranscriptBuilder() + msgs = [ChatMessage(role="assistant", content=None)] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 1 + jsonl = builder.to_jsonl() + assert "text" in jsonl + + def test_tool_role_with_tool_call_id_appended(self): + """Tool result messages are appended when tool_call_id is set.""" + builder = TranscriptBuilder() + # Need a preceding assistant tool_use entry + builder.append_user("use tool") + builder.append_assistant( + content_blocks=[ + {"type": "tool_use", "id": "tc-1", "name": "my_tool", "input": {}} + ] + ) + msgs = [ChatMessage(role="tool", tool_call_id="tc-1", content="result")] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 3 + assert "tool_result" in builder.to_jsonl() + + def test_tool_role_without_tool_call_id_skipped(self): + """Tool messages without tool_call_id are silently skipped.""" + builder = TranscriptBuilder() + msgs = [ChatMessage(role="tool", tool_call_id=None, content="orphan")] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 0 + + def test_tool_call_missing_function_key_uses_unknown_name(self): + """A tool_call dict 
with no 'function' key uses 'unknown' as the tool name.""" + builder = TranscriptBuilder() + # Tool call dict exists but 'function' sub-dict is missing entirely + msgs = [ + ChatMessage(role="assistant", content=None, tool_calls=[{"id": "tc-x"}]) + ] + _append_gap_to_builder(msgs, builder) + assert builder.entry_count == 1 + jsonl = builder.to_jsonl() + assert "unknown" in jsonl diff --git a/autogpt_platform/backend/backend/copilot/context.py b/autogpt_platform/backend/backend/copilot/context.py index 895aa6c4a1..7a22f02cb2 100644 --- a/autogpt_platform/backend/backend/copilot/context.py +++ b/autogpt_platform/backend/backend/copilot/context.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: # Allowed base directory for the Read tool. Public so service.py can use it # for sweep operations without depending on a private implementation detail. # Respects CLAUDE_CONFIG_DIR env var, consistent with transcript.py's -# _projects_base() function. +# projects_base() function. _config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude") SDK_PROJECTS_DIR = os.path.realpath(os.path.join(_config_dir, "projects")) diff --git a/autogpt_platform/backend/backend/copilot/db.py b/autogpt_platform/backend/backend/copilot/db.py index b85e08606c..bc4964ec35 100644 --- a/autogpt_platform/backend/backend/copilot/db.py +++ b/autogpt_platform/backend/backend/copilot/db.py @@ -10,9 +10,11 @@ from prisma.models import ChatMessage as PrismaChatMessage from prisma.models import ChatSession as PrismaChatSession from prisma.types import ( ChatMessageCreateInput, + ChatMessageWhereInput, ChatSessionCreateInput, ChatSessionUpdateInput, ChatSessionWhereInput, + FindManyChatMessageArgsFromChatSession, ) from pydantic import BaseModel @@ -30,6 +32,8 @@ from .model import get_chat_session as get_chat_session_cached logger = logging.getLogger(__name__) +_BOUNDARY_SCAN_LIMIT = 10 + class PaginatedMessages(BaseModel): """Result of a paginated message query.""" @@ -37,6 +41,7 @@ class PaginatedMessages(BaseModel): messages: list[ChatMessage] has_more: bool oldest_sequence: int | None + newest_sequence: int | None session: ChatSessionInfo @@ -61,32 +66,48 @@ async def get_chat_messages_paginated( session_id: str, limit: int = 50, before_sequence: int | None = None, + after_sequence: int | None = None, + from_start: bool = False, user_id: str | None = None, ) -> PaginatedMessages | None: - """Get paginated messages for a session, newest first. + """Get paginated messages for a session. - Verifies session existence (and ownership when ``user_id`` is provided) - in parallel with the message query. Returns ``None`` when the session - is not found or does not belong to the user. + Three modes: - Args: - session_id: The chat session ID. - limit: Max messages to return. - before_sequence: Cursor — return messages with sequence < this value. - user_id: If provided, filters via ``Session.userId`` so only the - session owner's messages are returned (acts as an ownership guard). + - ``before_sequence`` set: backward pagination (DESC), returns messages + with sequence < ``before_sequence``. Used for active sessions or manual + backward navigation. + - ``from_start=True`` or ``after_sequence`` set: forward pagination (ASC). + Returns messages from sequence 0 (``from_start``) or after + ``after_sequence``. Used on initial load of completed sessions and for + loading subsequent forward pages. + - Both cursors ``None`` and ``from_start=False``: newest-first (DESC + without filter). Used for active sessions on initial load. 
+ + Verifies session existence (and ownership when ``user_id`` is provided). + Returns ``None`` when the session is not found or does not belong to the + user. """ # Build session-existence / ownership check session_where: ChatSessionWhereInput = {"id": session_id} if user_id is not None: session_where["userId"] = user_id - # Build message include — fetch paginated messages in the same query - msg_include: dict[str, Any] = { - "order_by": {"sequence": "desc"}, + forward = from_start or after_sequence is not None + + # Build message include — fetch paginated messages in the same query. + # Note: when both from_start=True and after_sequence is not None, the + # after_sequence filter takes precedence (the elif branch below is skipped). + # This combination is not reachable via the HTTP route (mutual exclusion is + # enforced there), so we rely on the documented priority here without an + # additional assertion. + msg_include: FindManyChatMessageArgsFromChatSession = { + "order_by": {"sequence": "asc" if forward else "desc"}, "take": limit + 1, } - if before_sequence is not None: + if after_sequence is not None: + msg_include["where"] = {"sequence": {"gt": after_sequence}} + elif before_sequence is not None: msg_include["where"] = {"sequence": {"lt": before_sequence}} # Single query: session existence/ownership + paginated messages @@ -104,57 +125,96 @@ async def get_chat_messages_paginated( has_more = len(results) > limit results = results[:limit] - # Reverse to ascending order - results.reverse() + if not forward: + # Backward mode: DB returned DESC; reverse to ascending order. + results.reverse() - # Tool-call boundary fix: if the oldest message is a tool message, - # expand backward to include the preceding assistant message that - # owns the tool_calls, so convertChatSessionMessagesToUiMessages - # can pair them correctly. - _BOUNDARY_SCAN_LIMIT = 10 - if results and results[0].role == "tool": - boundary_where: dict[str, Any] = { - "sessionId": session_id, - "sequence": {"lt": results[0].sequence}, - } - if user_id is not None: - boundary_where["Session"] = {"is": {"userId": user_id}} - extra = await PrismaChatMessage.prisma().find_many( - where=boundary_where, - order={"sequence": "desc"}, - take=_BOUNDARY_SCAN_LIMIT, - ) - # Find the first non-tool message (should be the assistant) - boundary_msgs = [] - found_owner = False - for msg in extra: - boundary_msgs.append(msg) - if msg.role != "tool": - found_owner = True - break - boundary_msgs.reverse() - if not found_owner: - logger.warning( - "Boundary expansion did not find owning assistant message " - "for session=%s before sequence=%s (%d msgs scanned)", - session_id, - results[0].sequence, - len(extra), + # Tool-call boundary fix: if the oldest message is a tool message, + # expand backward to include the preceding assistant message that + # owns the tool_calls, so convertChatSessionMessagesToUiMessages + # can pair them correctly. + if results and results[0].role == "tool": + boundary_where: ChatMessageWhereInput = { + "sessionId": session_id, + "sequence": {"lt": results[0].sequence}, + } + if user_id is not None: + boundary_where["Session"] = {"is": {"userId": user_id}} + extra = await PrismaChatMessage.prisma().find_many( + where=boundary_where, + order={"sequence": "desc"}, + take=_BOUNDARY_SCAN_LIMIT, ) - if boundary_msgs: - results = boundary_msgs + results - # Only mark has_more if the expanded boundary isn't the - # very start of the conversation (sequence 0). 
- if boundary_msgs[0].sequence > 0: + # Find the first non-tool message (should be the assistant) + boundary_msgs = [] + found_owner = False + for msg in extra: + boundary_msgs.append(msg) + if msg.role != "tool": + found_owner = True + break + boundary_msgs.reverse() + if not found_owner: + logger.warning( + "Boundary expansion did not find owning assistant message " + "for session=%s before sequence=%s (%d msgs scanned)", + session_id, + results[0].sequence, + len(extra), + ) + if boundary_msgs: + results = boundary_msgs + results + # Only mark has_more if the expanded boundary isn't the + # very start of the conversation (sequence 0). + if boundary_msgs[0].sequence > 0: + has_more = True + else: + # Forward mode: DB returned ASC. + # Tool-call tail boundary fix: if the last message in this page is a + # tool message, the NEXT forward page would start after it and begin + # mid-tool-group — the owning assistant message is on this page but + # the following tool results are on the next page. + # Trim the current page so it ends on the owning assistant message, + # which keeps tool groups intact across page boundaries. + if results and results[-1].role == "tool": + # Walk backward through results to find the last non-tool message. + trim_idx = len(results) - 1 + while trim_idx >= 0 and results[trim_idx].role == "tool": + trim_idx -= 1 + + if trim_idx >= 0: + # Trim results so the page ends at the owning assistant. + # Mark has_more=True so the client knows to fetch the rest. + results = results[: trim_idx + 1] has_more = True + else: + # Entire page is tool messages with no visible owner — log and + # keep as-is so the caller is not stuck with an empty page. + logger.warning( + "Forward tail boundary: entire page is tool messages " + "for session=%s, no owning assistant found (%d msgs)", + session_id, + len(results), + ) messages = [ChatMessage.from_db(m) for m in results] - oldest_sequence = messages[0].sequence if messages else None + # oldest_sequence is only meaningful in backward mode (used as backward + # pagination cursor). In forward mode the page always starts near seq 0 + # and clients should use newest_sequence as the forward cursor instead. + # Return None in forward mode so clients don't accidentally treat it as a + # backward cursor on a forward-paginated session. + oldest_sequence = messages[0].sequence if (messages and not forward) else None + # newest_sequence is only meaningful in forward mode; in backward mode it + # points to the last message of the page (not the session's newest message) + # which is not a valid forward cursor. Return None in backward mode so + # clients don't accidentally use it as one. 
+ newest_sequence = messages[-1].sequence if (messages and forward) else None return PaginatedMessages( messages=messages, has_more=has_more, oldest_sequence=oldest_sequence, + newest_sequence=newest_sequence, session=session_info, ) diff --git a/autogpt_platform/backend/backend/copilot/db_test.py b/autogpt_platform/backend/backend/copilot/db_test.py index a2eb050bc4..f9e7ad515f 100644 --- a/autogpt_platform/backend/backend/copilot/db_test.py +++ b/autogpt_platform/backend/backend/copilot/db_test.py @@ -175,6 +175,187 @@ async def test_no_where_on_messages_without_before_sequence( assert "where" not in include["Messages"] +# ---------- Forward pagination (from_start / after_sequence) ---------- + + +@pytest.mark.asyncio +async def test_from_start_uses_asc_order_no_where( + mock_db: tuple[AsyncMock, AsyncMock], +): + """from_start=True queries messages in ASC order with no where filter.""" + find_first, _ = mock_db + find_first.return_value = _make_session( + messages=[_make_msg(0), _make_msg(1), _make_msg(2)], + ) + + await get_chat_messages_paginated(SESSION_ID, limit=50, from_start=True) + + call_kwargs = find_first.call_args + include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include") + assert include["Messages"]["order_by"] == {"sequence": "asc"} + assert "where" not in include["Messages"] + + +@pytest.mark.asyncio +async def test_from_start_returns_messages_ascending( + mock_db: tuple[AsyncMock, AsyncMock], +): + """from_start=True returns messages in ascending sequence order.""" + find_first, _ = mock_db + find_first.return_value = _make_session( + messages=[_make_msg(0), _make_msg(1), _make_msg(2)], + ) + + page = await get_chat_messages_paginated(SESSION_ID, limit=50, from_start=True) + + assert page is not None + assert [m.sequence for m in page.messages] == [0, 1, 2] + assert ( + page.oldest_sequence is None + ) # None in forward mode — not a valid backward cursor + assert page.newest_sequence == 2 + assert page.has_more is False + + +@pytest.mark.asyncio +async def test_from_start_has_more_when_results_exceed_limit( + mock_db: tuple[AsyncMock, AsyncMock], +): + """from_start=True sets has_more when DB returns more than limit items.""" + find_first, _ = mock_db + find_first.return_value = _make_session( + messages=[_make_msg(0), _make_msg(1), _make_msg(2)], + ) + + page = await get_chat_messages_paginated(SESSION_ID, limit=2, from_start=True) + + assert page is not None + assert page.has_more is True + assert [m.sequence for m in page.messages] == [0, 1] + assert page.newest_sequence == 1 + + +@pytest.mark.asyncio +async def test_after_sequence_uses_gt_filter_asc_order( + mock_db: tuple[AsyncMock, AsyncMock], +): + """after_sequence adds a sequence > N where clause and uses ASC order.""" + find_first, _ = mock_db + find_first.return_value = _make_session( + messages=[_make_msg(11), _make_msg(12)], + ) + + await get_chat_messages_paginated(SESSION_ID, limit=50, after_sequence=10) + + call_kwargs = find_first.call_args + include = call_kwargs.kwargs.get("include") or call_kwargs[1].get("include") + assert include["Messages"]["order_by"] == {"sequence": "asc"} + assert include["Messages"]["where"] == {"sequence": {"gt": 10}} + + +@pytest.mark.asyncio +async def test_after_sequence_returns_messages_in_order( + mock_db: tuple[AsyncMock, AsyncMock], +): + """after_sequence returns only messages with sequence > cursor, ascending.""" + find_first, _ = mock_db + find_first.return_value = _make_session( + messages=[_make_msg(11), _make_msg(12), _make_msg(13)], + ) + + 
page = await get_chat_messages_paginated(SESSION_ID, limit=50, after_sequence=10) + + assert page is not None + assert [m.sequence for m in page.messages] == [11, 12, 13] + assert ( + page.oldest_sequence is None + ) # None in forward mode — not a valid backward cursor + assert page.newest_sequence == 13 + assert page.has_more is False + + +@pytest.mark.asyncio +async def test_newest_sequence_none_for_backward_mode( + mock_db: tuple[AsyncMock, AsyncMock], +): + """newest_sequence is None in backward mode — it is not a valid forward cursor.""" + find_first, _ = mock_db + find_first.return_value = _make_session( + messages=[_make_msg(5), _make_msg(4), _make_msg(3)], + ) + + page = await get_chat_messages_paginated(SESSION_ID, limit=50) + + assert page is not None + assert page.newest_sequence is None + assert page.oldest_sequence == 3 + + +@pytest.mark.asyncio +async def test_forward_mode_no_boundary_expansion( + mock_db: tuple[AsyncMock, AsyncMock], +): + """Forward pagination never triggers backward boundary expansion.""" + find_first, find_many = mock_db + find_first.return_value = _make_session( + messages=[_make_msg(0, role="tool"), _make_msg(1, role="tool")], + ) + + await get_chat_messages_paginated(SESSION_ID, limit=50, from_start=True) + + assert find_many.call_count == 0 + + +@pytest.mark.asyncio +async def test_forward_tail_boundary_trims_trailing_tool_messages( + mock_db: tuple[AsyncMock, AsyncMock], +): + """Forward pages that end with tool messages are trimmed to the owning + assistant so the next after_sequence page doesn't start mid-tool-group.""" + find_first, _ = mock_db + # DB returns 4 messages ASC: assistant at 0, tool at 1, tool at 2, tool at 3 + find_first.return_value = _make_session( + messages=[ + _make_msg(0, role="assistant"), + _make_msg(1, role="tool"), + _make_msg(2, role="tool"), + _make_msg(3, role="tool"), + ], + ) + + page = await get_chat_messages_paginated(SESSION_ID, limit=10, from_start=True) + + assert page is not None + # Page should be trimmed to end at the assistant message + assert [m.sequence for m in page.messages] == [0] + assert page.newest_sequence == 0 + # has_more must be True so the client fetches the tool messages on next page + assert page.has_more is True + + +@pytest.mark.asyncio +async def test_forward_tail_boundary_no_trim_when_last_not_tool( + mock_db: tuple[AsyncMock, AsyncMock], +): + """Forward pages that end with a non-tool message are not trimmed.""" + find_first, _ = mock_db + find_first.return_value = _make_session( + messages=[ + _make_msg(0, role="user"), + _make_msg(1, role="assistant"), + _make_msg(2, role="tool"), + _make_msg(3, role="assistant"), + ], + ) + + page = await get_chat_messages_paginated(SESSION_ID, limit=10, from_start=True) + + assert page is not None + assert [m.sequence for m in page.messages] == [0, 1, 2, 3] + assert page.newest_sequence == 3 + assert page.has_more is False + + @pytest.mark.asyncio async def test_user_id_filter_applied_to_session_where( mock_db: tuple[AsyncMock, AsyncMock], diff --git a/autogpt_platform/backend/backend/copilot/message_dedup.py b/autogpt_platform/backend/backend/copilot/message_dedup.py deleted file mode 100644 index 2af13b559a..0000000000 --- a/autogpt_platform/backend/backend/copilot/message_dedup.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Per-request idempotency lock for the /stream endpoint. - -Prevents duplicate executor tasks from concurrent or retried POSTs (e.g. k8s -rolling-deploy retries, nginx upstream retries, rapid double-clicks). - -Lifecycle ---------- -1. 
``acquire()`` — computes a stable hash of (session_id, message, file_ids) - and atomically sets a Redis NX key. Returns a ``_DedupLock`` on success or - ``None`` when the key already exists (duplicate request). -2. ``release()`` — deletes the key. Must be called on turn completion or turn - error so the next legitimate send is never blocked. -3. On client disconnect (``GeneratorExit``) the lock must NOT be released — - the backend turn is still running, and releasing would reopen the duplicate - window for infra-level retries. The 30 s TTL is the safety net. -""" - -import hashlib -import logging - -from backend.data.redis_client import get_redis_async - -logger = logging.getLogger(__name__) - -_KEY_PREFIX = "chat:msg_dedup" -_TTL_SECONDS = 30 - - -class _DedupLock: - def __init__(self, key: str, redis) -> None: - self._key = key - self._redis = redis - - async def release(self) -> None: - """Best-effort key deletion. The TTL handles failures silently.""" - try: - await self._redis.delete(self._key) - except Exception: - pass - - -async def acquire_dedup_lock( - session_id: str, - message: str | None, - file_ids: list[str] | None, -) -> _DedupLock | None: - """Acquire the idempotency lock for this (session, message, files) tuple. - - Returns a ``_DedupLock`` when the lock is freshly acquired (first request). - Returns ``None`` when a duplicate is detected (lock already held). - Returns ``None`` when there is nothing to deduplicate (no message, no files). - """ - if not message and not file_ids: - return None - - sorted_ids = ":".join(sorted(file_ids or [])) - content_hash = hashlib.sha256( - f"{session_id}:{message or ''}:{sorted_ids}".encode() - ).hexdigest()[:16] - key = f"{_KEY_PREFIX}:{session_id}:{content_hash}" - - redis = await get_redis_async() - acquired = await redis.set(key, "1", ex=_TTL_SECONDS, nx=True) - if not acquired: - logger.warning( - f"[STREAM] Duplicate user message blocked for session {session_id}, " - f"hash={content_hash} — returning empty SSE", - ) - return None - - return _DedupLock(key, redis) diff --git a/autogpt_platform/backend/backend/copilot/message_dedup_test.py b/autogpt_platform/backend/backend/copilot/message_dedup_test.py deleted file mode 100644 index 935ddd36b6..0000000000 --- a/autogpt_platform/backend/backend/copilot/message_dedup_test.py +++ /dev/null @@ -1,94 +0,0 @@ -"""Unit tests for backend.copilot.message_dedup.""" - -from unittest.mock import AsyncMock - -import pytest -import pytest_mock - -from backend.copilot.message_dedup import _KEY_PREFIX, acquire_dedup_lock - - -def _patch_redis(mocker: pytest_mock.MockerFixture, *, set_returns): - mock_redis = AsyncMock() - mock_redis.set = AsyncMock(return_value=set_returns) - mocker.patch( - "backend.copilot.message_dedup.get_redis_async", - new_callable=AsyncMock, - return_value=mock_redis, - ) - return mock_redis - - -@pytest.mark.asyncio -async def test_acquire_returns_none_when_no_message_no_files( - mocker: pytest_mock.MockerFixture, -) -> None: - """Nothing to deduplicate — no Redis call made, None returned.""" - mock_redis = _patch_redis(mocker, set_returns=True) - result = await acquire_dedup_lock("sess-1", None, None) - assert result is None - mock_redis.set.assert_not_called() - - -@pytest.mark.asyncio -async def test_acquire_returns_lock_on_first_request( - mocker: pytest_mock.MockerFixture, -) -> None: - """First request acquires the lock and returns a _DedupLock.""" - mock_redis = _patch_redis(mocker, set_returns=True) - lock = await acquire_dedup_lock("sess-1", "hello", None) - assert 
lock is not None - mock_redis.set.assert_called_once() - key_arg = mock_redis.set.call_args.args[0] - assert key_arg.startswith(f"{_KEY_PREFIX}:sess-1:") - - -@pytest.mark.asyncio -async def test_acquire_returns_none_on_duplicate( - mocker: pytest_mock.MockerFixture, -) -> None: - """Duplicate request (NX fails) returns None to signal the caller.""" - _patch_redis(mocker, set_returns=None) - result = await acquire_dedup_lock("sess-1", "hello", None) - assert result is None - - -@pytest.mark.asyncio -async def test_acquire_key_stable_across_file_order( - mocker: pytest_mock.MockerFixture, -) -> None: - """File IDs are sorted before hashing so order doesn't affect the key.""" - mock_redis_1 = _patch_redis(mocker, set_returns=True) - await acquire_dedup_lock("sess-1", "msg", ["b", "a"]) - key_ab = mock_redis_1.set.call_args.args[0] - - mock_redis_2 = _patch_redis(mocker, set_returns=True) - await acquire_dedup_lock("sess-1", "msg", ["a", "b"]) - key_ba = mock_redis_2.set.call_args.args[0] - - assert key_ab == key_ba - - -@pytest.mark.asyncio -async def test_release_deletes_key( - mocker: pytest_mock.MockerFixture, -) -> None: - """release() calls Redis delete exactly once.""" - mock_redis = _patch_redis(mocker, set_returns=True) - lock = await acquire_dedup_lock("sess-1", "hello", None) - assert lock is not None - await lock.release() - mock_redis.delete.assert_called_once() - - -@pytest.mark.asyncio -async def test_release_swallows_redis_error( - mocker: pytest_mock.MockerFixture, -) -> None: - """release() must not raise even when Redis delete fails.""" - mock_redis = _patch_redis(mocker, set_returns=True) - mock_redis.delete = AsyncMock(side_effect=RuntimeError("redis down")) - lock = await acquire_dedup_lock("sess-1", "hello", None) - assert lock is not None - await lock.release() # must not raise - mock_redis.delete.assert_called_once() diff --git a/autogpt_platform/backend/backend/copilot/model.py b/autogpt_platform/backend/backend/copilot/model.py index fdde7fcddf..436630c90e 100644 --- a/autogpt_platform/backend/backend/copilot/model.py +++ b/autogpt_platform/backend/backend/copilot/model.py @@ -1,9 +1,8 @@ -import asyncio import logging import uuid +from contextlib import asynccontextmanager from datetime import UTC, datetime -from typing import Any, Callable, Self, cast -from weakref import WeakValueDictionary +from typing import Any, AsyncIterator, Callable, Self, cast from openai.types.chat import ( ChatCompletionAssistantMessageParam, @@ -522,10 +521,7 @@ async def upsert_chat_session( callers are aware of the persistence failure. RedisError: If the cache write fails (after successful DB write). """ - # Acquire session-specific lock to prevent concurrent upserts - lock = await _get_session_lock(session.session_id) - - async with lock: + async with _get_session_lock(session.session_id) as _: # Always query DB for existing message count to ensure consistency existing_message_count = await chat_db().get_next_sequence(session.session_id) @@ -651,20 +647,50 @@ async def _save_session_to_db( msg.sequence = existing_message_count + i -async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession: +async def append_and_save_message( + session_id: str, message: ChatMessage +) -> ChatSession | None: """Atomically append a message to a session and persist it. - Acquires the session lock, re-fetches the latest session state, - appends the message, and saves — preventing message loss when - concurrent requests modify the same session. 
- """ - lock = await _get_session_lock(session_id) + Returns the updated session, or None if the message was detected as a + duplicate (idempotency guard). Callers must check for None and skip any + downstream work (e.g. enqueuing a new LLM turn) when a duplicate is detected. - async with lock: - session = await get_chat_session(session_id) + Uses _get_session_lock (Redis NX) to serialise concurrent writers across replicas. + The idempotency check below provides a last-resort guard when the lock degrades. + """ + async with _get_session_lock(session_id) as lock_acquired: + # When the lock degraded (Redis down or 2s timeout), bypass cache for + # the idempotency check. Stale cache could let two concurrent writers + # both see the old state, pass the check, and write the same message. + if lock_acquired: + session = await get_chat_session(session_id) + else: + session = await _get_session_from_db(session_id) if session is None: raise ValueError(f"Session {session_id} not found") + # Idempotency: skip if the trailing block of same-role messages already + # contains this content. Uses is_message_duplicate which checks all + # consecutive trailing messages of the same role, not just [-1]. + # + # This collapses infra/nginx retries whether they land on the same pod + # (serialised by the Redis lock) or a different pod. + # + # Legit same-text messages are distinguished by the assistant turn + # between them: if the user said "yes", got a response, and says + # "yes" again, session.messages[-1] is the assistant reply, so the + # role check fails and the second message goes through normally. + # + # Edge case: if a turn dies without writing any assistant message, + # the user's next send of the same text is blocked here permanently. + # The fix is to ensure failed turns always write an error/timeout + # assistant message so the session always ends on an assistant turn. + if message.content is not None and is_message_duplicate( + session.messages, message.role, message.content + ): + return None # duplicate — caller should skip enqueue + session.messages.append(message) existing_message_count = await chat_db().get_next_sequence(session_id) @@ -679,6 +705,9 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat await cache_chat_session(session) except Exception as e: logger.warning(f"Cache write failed for session {session_id}: {e}") + # Invalidate the stale entry so future reads fall back to DB, + # preventing a retry from bypassing the idempotency check above. + await invalidate_session_cache(session_id) return session @@ -699,9 +728,7 @@ async def append_message_if( Returns the updated session on append, or ``None`` if the predicate rejected, the session no longer exists, or the append failed. """ - lock = await _get_session_lock(session_id) - - async with lock: + async with _get_session_lock(session_id) as _lock_acquired: # Read from DB directly — the Redis cache can be stale because the # executor's upsert_chat_session overwrites it with in-memory copies # during streaming, which may not include messages appended by the @@ -815,10 +842,6 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo except Exception as e: logger.warning(f"Failed to delete session {session_id} from cache: {e}") - # Clean up session lock (belt-and-suspenders with WeakValueDictionary) - async with _session_locks_mutex: - _session_locks.pop(session_id, None) - # Shut down any local browser daemon for this session (best-effort). 
# Inline import required: all tool modules import ChatSession from this # module, so any top-level import from tools.* would create a cycle. @@ -883,25 +906,38 @@ async def update_session_title( # ==================== Chat session locks ==================== # -_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary() -_session_locks_mutex = asyncio.Lock() +@asynccontextmanager +async def _get_session_lock(session_id: str) -> AsyncIterator[bool]: + """Distributed Redis lock for a session, usable as an async context manager. -async def _get_session_lock(session_id: str) -> asyncio.Lock: - """Get or create a lock for a specific session to prevent concurrent upserts. + Yields True if the lock was acquired, False if it timed out or Redis was + unavailable. Callers should treat False as a degraded mode and prefer fresh + DB reads over cache to avoid acting on stale state. - This was originally added to solve the specific problem of race conditions between - the session title thread and the conversation thread, which always occurs on the - same instance as we prevent rapid request sends on the frontend. - - Uses WeakValueDictionary for automatic cleanup: locks are garbage collected - when no coroutine holds a reference to them, preventing memory leaks from - unbounded growth of session locks. Explicit cleanup also occurs - in `delete_chat_session()`. + Uses redis-py's built-in Lock (Lua-script acquire/release) so lock acquisition + is atomic and release is owner-verified. Blocks up to 2s for a concurrent + writer to finish; the 10s TTL ensures a dead pod never holds the lock forever. """ - async with _session_locks_mutex: - lock = _session_locks.get(session_id) - if lock is None: - lock = asyncio.Lock() - _session_locks[session_id] = lock - return lock + _lock_key = f"copilot:session_lock:{session_id}" + lock = None + acquired = False + try: + _redis = await get_redis_async() + lock = _redis.lock(_lock_key, timeout=10, blocking_timeout=2) + acquired = await lock.acquire(blocking=True) + if not acquired: + logger.warning( + "Could not acquire session lock for %s within 2s", session_id + ) + except Exception as e: + logger.warning("Redis unavailable for session lock on %s: %s", session_id, e) + + try: + yield acquired + finally: + if acquired and lock is not None: + try: + await lock.release() + except Exception: + pass # TTL will expire the key diff --git a/autogpt_platform/backend/backend/copilot/model_test.py b/autogpt_platform/backend/backend/copilot/model_test.py index c78d63cc5a..e97ac24d51 100644 --- a/autogpt_platform/backend/backend/copilot/model_test.py +++ b/autogpt_platform/backend/backend/copilot/model_test.py @@ -11,11 +11,13 @@ from openai.types.chat.chat_completion_message_tool_call_param import ( ChatCompletionMessageToolCallParam, Function, ) +from pytest_mock import MockerFixture from .model import ( ChatMessage, ChatSession, Usage, + append_and_save_message, get_chat_session, is_message_duplicate, maybe_append_user_message, @@ -574,3 +576,345 @@ def test_maybe_append_assistant_skips_duplicate(): result = maybe_append_user_message(session, "dup", is_user_message=False) assert result is False assert len(session.messages) == 2 + + +# --------------------------------------------------------------------------- # +# append_and_save_message # +# --------------------------------------------------------------------------- # + + +def _make_session_with_messages(*msgs: ChatMessage) -> ChatSession: + s = ChatSession.new(user_id="u1", dry_run=False) + s.messages = 
list(msgs) + return s + + +@pytest.mark.asyncio(loop_scope="session") +async def test_append_and_save_message_returns_none_for_duplicate( + mocker: MockerFixture, +) -> None: + """append_and_save_message returns None when the trailing message is a duplicate.""" + + session = _make_session_with_messages( + ChatMessage(role="user", content="hello"), + ) + mock_redis_lock = mocker.AsyncMock() + mock_redis_lock.acquire = mocker.AsyncMock(return_value=True) + mock_redis_lock.release = mocker.AsyncMock() + mock_redis_client = mocker.MagicMock() + mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock) + mocker.patch( + "backend.copilot.model.get_redis_async", + new_callable=mocker.AsyncMock, + return_value=mock_redis_client, + ) + mocker.patch( + "backend.copilot.model.get_chat_session", + new_callable=mocker.AsyncMock, + return_value=session, + ) + + result = await append_and_save_message( + session.session_id, ChatMessage(role="user", content="hello") + ) + assert result is None + + +@pytest.mark.asyncio(loop_scope="session") +async def test_append_and_save_message_appends_new_message( + mocker: MockerFixture, +) -> None: + """append_and_save_message appends a non-duplicate message and returns the session.""" + + session = _make_session_with_messages( + ChatMessage(role="user", content="hello"), + ChatMessage(role="assistant", content="hi"), + ) + mock_redis_lock = mocker.AsyncMock() + mock_redis_lock.acquire = mocker.AsyncMock(return_value=True) + mock_redis_lock.release = mocker.AsyncMock() + mock_redis_client = mocker.MagicMock() + mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock) + mocker.patch( + "backend.copilot.model.get_redis_async", + new_callable=mocker.AsyncMock, + return_value=mock_redis_client, + ) + mocker.patch( + "backend.copilot.model.get_chat_session", + new_callable=mocker.AsyncMock, + return_value=session, + ) + mocker.patch( + "backend.copilot.model._save_session_to_db", + new_callable=mocker.AsyncMock, + ) + mocker.patch( + "backend.copilot.model.chat_db", + return_value=mocker.MagicMock( + get_next_sequence=mocker.AsyncMock(return_value=2) + ), + ) + mocker.patch( + "backend.copilot.model.cache_chat_session", + new_callable=mocker.AsyncMock, + ) + + new_msg = ChatMessage(role="user", content="second message") + result = await append_and_save_message(session.session_id, new_msg) + assert result is not None + assert result.messages[-1].content == "second message" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_append_and_save_message_raises_when_session_not_found( + mocker: MockerFixture, +) -> None: + """append_and_save_message raises ValueError when the session does not exist.""" + + mock_redis_lock = mocker.AsyncMock() + mock_redis_lock.acquire = mocker.AsyncMock(return_value=True) + mock_redis_lock.release = mocker.AsyncMock() + mock_redis_client = mocker.MagicMock() + mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock) + mocker.patch( + "backend.copilot.model.get_redis_async", + new_callable=mocker.AsyncMock, + return_value=mock_redis_client, + ) + mocker.patch( + "backend.copilot.model.get_chat_session", + new_callable=mocker.AsyncMock, + return_value=None, + ) + + with pytest.raises(ValueError, match="not found"): + await append_and_save_message( + "missing-session-id", ChatMessage(role="user", content="hi") + ) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_append_and_save_message_uses_db_when_lock_degraded( + mocker: MockerFixture, +) -> None: + """When the Redis lock times 
out (acquired=False), the fallback reads from DB.""" + + session = _make_session_with_messages( + ChatMessage(role="assistant", content="hi"), + ) + mock_redis_lock = mocker.AsyncMock() + mock_redis_lock.acquire = mocker.AsyncMock(return_value=False) + mock_redis_client = mocker.MagicMock() + mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock) + mocker.patch( + "backend.copilot.model.get_redis_async", + new_callable=mocker.AsyncMock, + return_value=mock_redis_client, + ) + mock_get_from_db = mocker.patch( + "backend.copilot.model._get_session_from_db", + new_callable=mocker.AsyncMock, + return_value=session, + ) + mocker.patch( + "backend.copilot.model._save_session_to_db", + new_callable=mocker.AsyncMock, + ) + mocker.patch( + "backend.copilot.model.chat_db", + return_value=mocker.MagicMock( + get_next_sequence=mocker.AsyncMock(return_value=1) + ), + ) + mocker.patch( + "backend.copilot.model.cache_chat_session", + new_callable=mocker.AsyncMock, + ) + + new_msg = ChatMessage(role="user", content="new msg") + result = await append_and_save_message(session.session_id, new_msg) + # DB path was used (not cache-first) + mock_get_from_db.assert_called_once_with(session.session_id) + assert result is not None + + +@pytest.mark.asyncio(loop_scope="session") +async def test_append_and_save_message_raises_database_error_on_save_failure( + mocker: MockerFixture, +) -> None: + """When _save_session_to_db fails, append_and_save_message raises DatabaseError.""" + from backend.util.exceptions import DatabaseError + + session = _make_session_with_messages( + ChatMessage(role="assistant", content="hi"), + ) + mock_redis_lock = mocker.AsyncMock() + mock_redis_lock.acquire = mocker.AsyncMock(return_value=True) + mock_redis_lock.release = mocker.AsyncMock() + mock_redis_client = mocker.MagicMock() + mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock) + mocker.patch( + "backend.copilot.model.get_redis_async", + new_callable=mocker.AsyncMock, + return_value=mock_redis_client, + ) + mocker.patch( + "backend.copilot.model.get_chat_session", + new_callable=mocker.AsyncMock, + return_value=session, + ) + mocker.patch( + "backend.copilot.model._save_session_to_db", + new_callable=mocker.AsyncMock, + side_effect=RuntimeError("db down"), + ) + mocker.patch( + "backend.copilot.model.chat_db", + return_value=mocker.MagicMock( + get_next_sequence=mocker.AsyncMock(return_value=1) + ), + ) + + with pytest.raises(DatabaseError): + await append_and_save_message( + session.session_id, ChatMessage(role="user", content="new msg") + ) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_append_and_save_message_invalidates_cache_on_cache_failure( + mocker: MockerFixture, +) -> None: + """When cache_chat_session fails, invalidate_session_cache is called to avoid stale reads.""" + + session = _make_session_with_messages( + ChatMessage(role="assistant", content="hi"), + ) + mock_redis_lock = mocker.AsyncMock() + mock_redis_lock.acquire = mocker.AsyncMock(return_value=True) + mock_redis_lock.release = mocker.AsyncMock() + mock_redis_client = mocker.MagicMock() + mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock) + mocker.patch( + "backend.copilot.model.get_redis_async", + new_callable=mocker.AsyncMock, + return_value=mock_redis_client, + ) + mocker.patch( + "backend.copilot.model.get_chat_session", + new_callable=mocker.AsyncMock, + return_value=session, + ) + mocker.patch( + "backend.copilot.model._save_session_to_db", + new_callable=mocker.AsyncMock, + ) + 
mocker.patch( + "backend.copilot.model.chat_db", + return_value=mocker.MagicMock( + get_next_sequence=mocker.AsyncMock(return_value=1) + ), + ) + mocker.patch( + "backend.copilot.model.cache_chat_session", + new_callable=mocker.AsyncMock, + side_effect=RuntimeError("redis write failed"), + ) + mock_invalidate = mocker.patch( + "backend.copilot.model.invalidate_session_cache", + new_callable=mocker.AsyncMock, + ) + + result = await append_and_save_message( + session.session_id, ChatMessage(role="user", content="new msg") + ) + # DB write succeeded, cache invalidation was called + mock_invalidate.assert_called_once_with(session.session_id) + assert result is not None + + +@pytest.mark.asyncio(loop_scope="session") +async def test_append_and_save_message_uses_db_when_redis_unavailable( + mocker: MockerFixture, +) -> None: + """When get_redis_async raises, _get_session_lock yields False (degraded) and DB is read.""" + + session = _make_session_with_messages( + ChatMessage(role="assistant", content="hi"), + ) + mocker.patch( + "backend.copilot.model.get_redis_async", + new_callable=mocker.AsyncMock, + side_effect=ConnectionError("redis down"), + ) + mock_get_from_db = mocker.patch( + "backend.copilot.model._get_session_from_db", + new_callable=mocker.AsyncMock, + return_value=session, + ) + mocker.patch( + "backend.copilot.model._save_session_to_db", + new_callable=mocker.AsyncMock, + ) + mocker.patch( + "backend.copilot.model.chat_db", + return_value=mocker.MagicMock( + get_next_sequence=mocker.AsyncMock(return_value=1) + ), + ) + mocker.patch( + "backend.copilot.model.cache_chat_session", + new_callable=mocker.AsyncMock, + ) + + new_msg = ChatMessage(role="user", content="new msg") + result = await append_and_save_message(session.session_id, new_msg) + mock_get_from_db.assert_called_once_with(session.session_id) + assert result is not None + + +@pytest.mark.asyncio(loop_scope="session") +async def test_append_and_save_message_lock_release_failure_is_ignored( + mocker: MockerFixture, +) -> None: + """If lock.release() raises, the exception is swallowed (TTL will clean up).""" + + session = _make_session_with_messages( + ChatMessage(role="assistant", content="hi"), + ) + mock_redis_lock = mocker.AsyncMock() + mock_redis_lock.acquire = mocker.AsyncMock(return_value=True) + mock_redis_lock.release = mocker.AsyncMock( + side_effect=RuntimeError("release failed") + ) + mock_redis_client = mocker.MagicMock() + mock_redis_client.lock = mocker.MagicMock(return_value=mock_redis_lock) + mocker.patch( + "backend.copilot.model.get_redis_async", + new_callable=mocker.AsyncMock, + return_value=mock_redis_client, + ) + mocker.patch( + "backend.copilot.model.get_chat_session", + new_callable=mocker.AsyncMock, + return_value=session, + ) + mocker.patch( + "backend.copilot.model._save_session_to_db", + new_callable=mocker.AsyncMock, + ) + mocker.patch( + "backend.copilot.model.chat_db", + return_value=mocker.MagicMock( + get_next_sequence=mocker.AsyncMock(return_value=1) + ), + ) + mocker.patch( + "backend.copilot.model.cache_chat_session", + new_callable=mocker.AsyncMock, + ) + + new_msg = ChatMessage(role="user", content="new msg") + result = await append_and_save_message(session.session_id, new_msg) + assert result is not None diff --git a/autogpt_platform/backend/backend/copilot/prompting.py b/autogpt_platform/backend/backend/copilot/prompting.py index ec136933e9..ed436733dd 100644 --- a/autogpt_platform/backend/backend/copilot/prompting.py +++ b/autogpt_platform/backend/backend/copilot/prompting.py @@ 
-174,6 +174,7 @@ sandbox so `bash_exec` can access it for further processing. The exact sandbox path is shown in the `[Sandbox copy available at ...]` note. ### GitHub CLI (`gh`) and git +- To check if the user has their GitHub account already connected, run `gh auth status`. Always check this before asking them to connect it. - If the user has connected their GitHub account, both `gh` and `git` are pre-authenticated — use them directly without any manual login step. `git` HTTPS operations (clone, push, pull) work automatically. diff --git a/autogpt_platform/backend/backend/copilot/sdk/mode_switch_context_test.py b/autogpt_platform/backend/backend/copilot/sdk/mode_switch_context_test.py index 5e1ef41979..212fca189b 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/mode_switch_context_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/mode_switch_context_test.py @@ -8,7 +8,7 @@ Cross-mode transcript flow ========================== Both ``baseline/service.py`` (fast mode) and ``sdk/service.py`` (extended_thinking -mode) read and write the same JSONL transcript store via +mode) read and write the same CLI session store via ``backend.copilot.transcript.upload_transcript`` / ``download_transcript``. @@ -250,8 +250,9 @@ class TestSdkToFastModeSwitch: @pytest.mark.asyncio async def test_scenario_s_baseline_loads_sdk_transcript(self): - """Scenario S: SDK-written transcript is accepted by baseline's load helper.""" + """Scenario S: SDK-written CLI session is accepted by baseline's load helper.""" from backend.copilot.baseline.service import _load_prior_transcript + from backend.copilot.model import ChatMessage from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload from backend.copilot.transcript_builder import TranscriptBuilder @@ -267,33 +268,41 @@ class TestSdkToFastModeSwitch: sdk_transcript = builder_sdk.to_jsonl() # Baseline session now has those 2 SDK messages + 1 new baseline message. - download = TranscriptDownload(content=sdk_transcript, message_count=2) + restore = TranscriptDownload( + content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk" + ) baseline_builder = TranscriptBuilder() with patch( "backend.copilot.baseline.service.download_transcript", - new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ): - covers = await _load_prior_transcript( + covers, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=3, # 2 SDK + 1 new baseline + session_messages=[ + ChatMessage(role="user", content="sdk-question"), + ChatMessage(role="assistant", content="sdk-answer"), + ChatMessage(role="user", content="baseline-question"), + ], transcript_builder=baseline_builder, ) - # Transcript is valid and covers the prefix. + # CLI session is valid and covers the prefix. assert covers is True + assert dl is not None assert baseline_builder.entry_count == 2 @pytest.mark.asyncio async def test_scenario_s_stale_sdk_transcript_not_loaded(self): - """Scenario S (stale): SDK transcript is stale — baseline does not load it. + """Scenario S (stale): SDK CLI session is stale — baseline does not load it. - If SDK mode produced more turns than the transcript captured (e.g. - upload failed on one turn), the baseline rejects the stale transcript + If SDK mode produced more turns than the session captured (e.g. + upload failed on one turn), the baseline rejects the stale session to avoid injecting an incomplete history. 
""" from backend.copilot.baseline.service import _load_prior_transcript + from backend.copilot.model import ChatMessage from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload from backend.copilot.transcript_builder import TranscriptBuilder @@ -306,21 +315,33 @@ class TestSdkToFastModeSwitch: ) sdk_transcript = builder_sdk.to_jsonl() - # Transcript covers only 2 messages but session has 10 (many SDK turns). - download = TranscriptDownload(content=sdk_transcript, message_count=2) + # Session covers only 2 messages but session has 10 (many SDK turns). + # With watermark=2 and 10 total messages, detect_gap will fill the gap + # by appending messages 2..8 (positions 2 to total-2). + restore = TranscriptDownload( + content=sdk_transcript.encode("utf-8"), message_count=2, mode="sdk" + ) + + # Build a session with 10 alternating user/assistant messages + current user + session_messages = [ + ChatMessage(role="user" if i % 2 == 0 else "assistant", content=f"msg-{i}") + for i in range(10) + ] baseline_builder = TranscriptBuilder() with patch( "backend.copilot.baseline.service.download_transcript", - new=AsyncMock(return_value=download), + new=AsyncMock(return_value=restore), ): - covers = await _load_prior_transcript( + covers, dl = await _load_prior_transcript( user_id="user-1", session_id="session-1", - session_msg_count=10, + session_messages=session_messages, transcript_builder=baseline_builder, ) - # Stale transcript must be rejected. - assert covers is False - assert baseline_builder.is_empty + # With gap filling, covers is True and gap messages are appended. + assert covers is True + assert dl is not None + # 2 from transcript + 7 gap messages (positions 2..8, excluding last user turn) + assert baseline_builder.entry_count == 9 diff --git a/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py b/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py index a48d7def3d..60c65f00ce 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py @@ -27,6 +27,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from backend.copilot.transcript import ( + TranscriptDownload, _flatten_assistant_content, _flatten_tool_result_content, _messages_to_transcript, @@ -999,14 +1000,15 @@ def _make_sdk_patches( f"{_SVC}.download_transcript", dict( new_callable=AsyncMock, - return_value=MagicMock(content=original_transcript, message_count=2), + return_value=TranscriptDownload( + content=original_transcript.encode("utf-8"), + message_count=2, + mode="sdk", + ), ), ), - ( - f"{_SVC}.restore_cli_session", - dict(new_callable=AsyncMock, return_value=True), - ), - (f"{_SVC}.upload_cli_session", dict(new_callable=AsyncMock)), + (f"{_SVC}.strip_for_upload", dict(return_value=original_transcript)), + (f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)), (f"{_SVC}.validate_transcript", dict(return_value=True)), ( f"{_SVC}.compact_transcript", @@ -1037,7 +1039,6 @@ def _make_sdk_patches( claude_agent_fallback_model=None, ), ), - (f"{_SVC}.upload_transcript", dict(new_callable=AsyncMock)), (f"{_SVC}.get_user_tier", dict(new_callable=AsyncMock, return_value=None)), ] @@ -1914,14 +1915,14 @@ class TestStreamChatCompletionRetryIntegration: compacted_transcript=None, client_side_effect=_client_factory, ) - # Override restore_cli_session to return False (CLI native session unavailable) + # Override download_transcript to return None (CLI native session 
unavailable) patches = [ ( ( - f"{_SVC}.restore_cli_session", - dict(new_callable=AsyncMock, return_value=False), + f"{_SVC}.download_transcript", + dict(new_callable=AsyncMock, return_value=None), ) - if p[0] == f"{_SVC}.restore_cli_session" + if p[0] == f"{_SVC}.download_transcript" else p ) for p in patches @@ -1944,7 +1945,7 @@ class TestStreamChatCompletionRetryIntegration: # captured_options holds {"options": ClaudeAgentOptions}, so check # the attribute directly rather than dict keys. assert not getattr(captured_options.get("options"), "resume", None), ( - f"--resume was set even though restore_cli_session returned False: " + f"--resume was set even though download_transcript returned None: " f"{captured_options}" ) assert any(isinstance(e, StreamStart) for e in events) diff --git a/autogpt_platform/backend/backend/copilot/sdk/security_hooks.py b/autogpt_platform/backend/backend/copilot/sdk/security_hooks.py index e5ba184f4f..666e55fbba 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/security_hooks.py +++ b/autogpt_platform/backend/backend/copilot/sdk/security_hooks.py @@ -365,7 +365,7 @@ def create_security_hooks( trigger = _sanitize(str(input_data.get("trigger", "auto")), max_len=50) # Sanitize untrusted input: strip control chars for logging AND # for the value passed downstream. read_compacted_entries() - # validates against _projects_base() as defence-in-depth, but + # validates against projects_base() as defence-in-depth, but # sanitizing here prevents log injection and rejects obviously # malformed paths early. transcript_path = _sanitize( diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py index ed27b7c134..9cef40ba7a 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service.py @@ -16,6 +16,7 @@ import uuid from collections.abc import AsyncGenerator, AsyncIterator from dataclasses import dataclass from dataclasses import field as dataclass_field +from pathlib import Path from typing import TYPE_CHECKING, Any, NamedTuple, NotRequired, cast if TYPE_CHECKING: @@ -92,12 +93,15 @@ from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path from ..tracking import track_user_message from ..transcript import ( _run_compression, + TranscriptDownload, cleanup_stale_project_dirs, + cli_session_path, compact_transcript, download_transcript, + extract_context_messages, + projects_base, read_compacted_entries, - restore_cli_session, - upload_cli_session, + strip_for_upload, upload_transcript, validate_transcript, ) @@ -121,7 +125,12 @@ config = ChatConfig() class _SystemPromptPreset(SystemPromptPreset, total=False): - """Extends SystemPromptPreset with fields added in claude-agent-sdk 0.1.59.""" + """Extends :class:`SystemPromptPreset` with ``exclude_dynamic_sections``. + + The field was added to the upstream TypedDict in claude-agent-sdk 0.1.59. + Until the package is pinned to that version we declare it locally so Pyright + accepts the kwarg without a ``# type: ignore`` comment. + """ exclude_dynamic_sections: NotRequired[bool] @@ -849,6 +858,181 @@ def _make_sdk_cwd(session_id: str) -> str: return cwd +def _write_cli_session_to_disk( + content: bytes, + sdk_cwd: str, + session_id: str, + log_prefix: str, +) -> bool: + """Write downloaded CLI session bytes to disk so the CLI can --resume. + + Returns True on success, False if the path is invalid or the write fails. 
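Both `_write_cli_session_to_disk` and `read_cli_session_from_disk` in this hunk apply the same realpath-then-prefix containment check before touching disk. As a standalone sketch of the pattern (the `resolve_under` helper name is hypothetical):

```python
import os

def resolve_under(path: str, base: str) -> str | None:
    """Illustrative: resolve symlinks and '..' first, then require containment."""
    real = os.path.realpath(path)
    # startswith(base + os.sep) rejects escapes and the base directory itself.
    return real if real.startswith(base + os.sep) else None
```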
+ Path-traversal guard: rejects paths outside the CLI projects base. + """ + session_file = cli_session_path(sdk_cwd, session_id) + real_path = os.path.realpath(session_file) + _pbase = projects_base() + if not real_path.startswith(_pbase + os.sep): + logger.warning( + "%s CLI session restore path outside projects base: %s", + log_prefix, + os.path.basename(session_file), + ) + return False + try: + os.makedirs(os.path.dirname(real_path), exist_ok=True) + Path(real_path).write_bytes(content) + logger.info( + "%s Wrote CLI session to disk (%dB) for --resume", + log_prefix, + len(content), + ) + return True + except OSError as e: + logger.warning( + "%s Failed to write CLI session file %s: %s", + log_prefix, + os.path.basename(session_file), + e.strerror or str(e), + ) + return False + + +def read_cli_session_from_disk( + sdk_cwd: str, + session_id: str, + log_prefix: str, +) -> bytes | None: + """Read the CLI session JSONL file from disk after the SDK turn. + + Returns the file bytes, or None if the file is missing, outside the + projects base, or unreadable. + Path-traversal guard: rejects paths outside the CLI projects base. + """ + session_file = cli_session_path(sdk_cwd, session_id) + real_path = os.path.realpath(session_file) + _pbase = projects_base() + if not real_path.startswith(_pbase + os.sep): + logger.warning( + "%s CLI session file outside projects base, skipping upload: %s", + log_prefix, + os.path.basename(real_path), + ) + return None + try: + raw_bytes = Path(real_path).read_bytes() + except FileNotFoundError: + logger.debug( + "%s CLI session file not found, skipping upload: %s", + log_prefix, + os.path.basename(session_file), + ) + return None + except OSError as e: + logger.warning( + "%s Failed to read CLI session file %s: %s", + log_prefix, + os.path.basename(session_file), + e.strerror or str(e), + ) + return None + + # Strip stale thinking blocks and metadata entries before uploading. + # Thinking blocks from non-last turns can be massive; keeping them causes + # the CLI to auto-compact its session when the context window fills up, + # silently losing conversation history. + try: + raw_text = raw_bytes.decode("utf-8") + stripped_text = strip_for_upload(raw_text) + stripped_bytes = stripped_text.encode("utf-8") + except UnicodeDecodeError: + logger.warning("%s CLI session is not valid UTF-8, uploading raw", log_prefix) + return raw_bytes + except (OSError, ValueError) as e: + # OSError: encode/decode I/O failure; ValueError: malformed JSONL in strip. + # Other unexpected exceptions are not silently swallowed here so they propagate + # to the outer OSError handler and are logged with exc_info. + logger.warning( + "%s Failed to strip CLI session, uploading raw: %s", log_prefix, e + ) + return raw_bytes + + if len(stripped_bytes) < len(raw_bytes): + # Write back locally so same-pod turns also benefit. + try: + Path(real_path).write_bytes(stripped_bytes) + logger.info( + "%s Stripped CLI session: %dB → %dB", + log_prefix, + len(raw_bytes), + len(stripped_bytes), + ) + except OSError as e: + # write_bytes failed — stripped content is still valid for GCS upload even + # though the local write-back failed (same-pod optimization silently skipped). 
+ logger.warning( + "%s Failed to write back stripped CLI session: %s", + log_prefix, + e.strerror or str(e), + ) + return stripped_bytes + + +def process_cli_restore( + cli_restore: TranscriptDownload, + sdk_cwd: str, + session_id: str, + log_prefix: str, +) -> tuple[str, bool]: + """Validate and write a restored CLI session to disk. + + Decodes bytes → UTF-8, strips progress entries and stale thinking blocks, + validates the result, then writes the stripped content to disk so the CLI + can ``--resume`` from it. + + Returns ``(stripped_content, success)`` where ``success=False`` means the + content was invalid or the disk write failed (caller should skip --resume). + """ + try: + raw_bytes = cli_restore.content + raw_str = ( + raw_bytes.decode("utf-8") if isinstance(raw_bytes, bytes) else raw_bytes + ) + except UnicodeDecodeError: + logger.warning( + "%s CLI session content is not valid UTF-8, skipping", log_prefix + ) + return "", False + + stripped = strip_for_upload(raw_str) + is_valid = validate_transcript(stripped) + # Use len(raw_str) rather than len(cli_restore.content) so the unit is always + # characters (raw_str is always str at this point regardless of input type). + # lines_stripped = original lines minus remaining lines after stripping. + _original_lines = len(raw_str.strip().split("\n")) if raw_str.strip() else 0 + _remaining_lines = len(stripped.strip().split("\n")) if stripped.strip() else 0 + logger.info( + "%s Restored CLI session: %dB raw, %d lines stripped, msg_count=%d, valid=%s", + log_prefix, + len(raw_str), + _original_lines - _remaining_lines, + cli_restore.message_count, + is_valid, + ) + if not is_valid: + logger.warning( + "%s CLI session content invalid after strip — running without --resume", + log_prefix, + ) + return "", False + + stripped_bytes = stripped.encode("utf-8") + if not _write_cli_session_to_disk(stripped_bytes, sdk_cwd, session_id, log_prefix): + return "", False + + return stripped, True + + async def _cleanup_sdk_tool_results(cwd: str) -> None: """Remove SDK session artifacts for a specific working directory. @@ -922,8 +1106,9 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]: result.append(block) else: logger.warning( - f"[SDK] Unknown content block type: {type(block).__name__}. " - f"This may indicate a new SDK version with additional block types." + "[SDK] Unknown content block type: %s." 
+ " This may indicate a new SDK version with additional block types.", + type(block).__name__, ) return result @@ -978,10 +1163,11 @@ async def _compress_messages( if result.was_compacted: logger.info( - f"[SDK] Context compacted: {result.original_token_count} -> " - f"{result.token_count} tokens " - f"({result.messages_summarized} summarized, " - f"{result.messages_dropped} dropped)" + "[SDK] Context compacted: %d -> %d tokens (%d summarized, %d dropped)", + result.original_token_count, + result.token_count, + result.messages_summarized, + result.messages_dropped, ) # Convert compressed dicts back to ChatMessages return [ @@ -1048,11 +1234,17 @@ def _session_messages_to_transcript(messages: list[ChatMessage]) -> str: ) if blocks: builder.append_assistant(blocks) - elif msg.role == "tool" and msg.tool_call_id: - builder.append_tool_result( - tool_use_id=msg.tool_call_id, - content=msg.content or "", - ) + elif msg.role == "tool": + if msg.tool_call_id: + builder.append_tool_result( + tool_use_id=msg.tool_call_id, + content=msg.content or "", + ) + else: + # Malformed tool message — no tool_call_id to link to an + # assistant tool_use block. Skip to avoid an unmatched + # tool_result entry in the builder (which would confuse --resume). + logger.warning("[SDK] Skipping tool gap message with no tool_call_id") return builder.to_jsonl() @@ -1098,6 +1290,7 @@ async def _build_query_message( transcript_msg_count: int, session_id: str, target_tokens: int | None = None, + prior_messages: "list[ChatMessage] | None" = None, ) -> tuple[str, bool]: """Build the query message with appropriate context. @@ -1203,15 +1396,16 @@ async def _build_query_message( ) return current_message, False + source = prior_messages if prior_messages is not None else prior logger.warning( - "[SDK] [%s] No --resume for %d-message session — compressing" - " full session history (pod affinity issue or first turn after" - " restore failure); target_tokens=%s", + "[SDK] [%s] No --resume for %d-message session — compressing context " + "(source=%s, target_tokens=%s)", session_id[:8], msg_count, + "transcript+gap" if prior_messages is not None else "full-db", target_tokens, ) - compressed, was_compressed = await _compress_messages(prior, target_tokens) + compressed, was_compressed = await _compress_messages(source, target_tokens) history_context = _format_conversation_context(compressed) if history_context: logger.info( @@ -1228,7 +1422,7 @@ async def _build_query_message( "[SDK] [%s] Fallback context empty after compression" " (%d messages) — sending message without history", session_id[:8], - len(prior), + len(source), ) return current_message, False @@ -2233,6 +2427,161 @@ async def _seed_transcript( return _seeded, True, len(_prior) +@dataclass +class _RestoreResult: + """Return value from ``_restore_cli_session_for_turn``.""" + + transcript_content: str = "" + transcript_covers_prefix: bool = True + use_resume: bool = False + resume_file: str | None = None + transcript_msg_count: int = 0 + baseline_download: "TranscriptDownload | None" = None + context_messages: "list[ChatMessage] | None" = None + + +async def _restore_cli_session_for_turn( + user_id: str | None, + session_id: str, + session: "ChatSession", + sdk_cwd: str, + transcript_builder: "TranscriptBuilder", + log_prefix: str, +) -> _RestoreResult: + """Download, validate and restore a CLI session for ``--resume`` on this turn. + + Performs a single GCS round-trip to fetch the session bytes + message_count + watermark. 
Falls back to DB-message reconstruction when GCS has no session + (first turn or upload missed). + + Returns a ``_RestoreResult`` with all transcript-related state ready for the + caller to merge into its local variables. + """ + result = _RestoreResult() + + if not (config.claude_agent_use_resume and user_id and len(session.messages) > 1): + return result + + try: + cli_restore = await download_transcript( + user_id, session_id, log_prefix=log_prefix + ) + except Exception as restore_err: + logger.warning( + "%s CLI session restore failed, continuing without --resume: %s", + log_prefix, + restore_err, + ) + cli_restore = None + + # Only attempt --resume for SDK-written transcripts. + # Baseline-written transcripts use TranscriptBuilder format (synthetic IDs, + # stripped fields) that may not be valid for --resume. + if cli_restore is not None and cli_restore.mode != "sdk": + logger.info( + "%s Transcript written by mode=%r — skipping --resume, " + "will use transcript content + gap for context", + log_prefix, + cli_restore.mode, + ) + result.baseline_download = cli_restore # keep for extract_context_messages + cli_restore = None + + # Validate, strip, and write to disk — delegate to helper to reduce + # function complexity. Writing an invalid/corrupt file to disk then + # falling back to "no --resume" would cause the CLI to fail with + # "Session ID already in use" because the file exists at the expected + # session path, so we validate BEFORE any disk write. + stripped = "" + if cli_restore is not None and sdk_cwd: + stripped, ok = process_cli_restore(cli_restore, sdk_cwd, session_id, log_prefix) + if not ok: + result.transcript_covers_prefix = False + cli_restore = None + + if cli_restore is None and sdk_cwd: + # Validation failed or GCS returned no session. Delete any + # existing local session file so the CLI doesn't reject the + # session_id with "Session ID already in use". T1 may have + # left a valid file at this path; we clear it so the fallback + # path (session_id= without --resume) can create a new session. + _stale_path = os.path.realpath(cli_session_path(sdk_cwd, session_id)) + if Path(_stale_path).exists() and _stale_path.startswith( + projects_base() + os.sep + ): + try: + Path(_stale_path).unlink() + logger.debug( + "%s Removed stale local CLI session file for clean fallback", + log_prefix, + ) + except OSError as _unlink_err: + logger.debug( + "%s Failed to remove stale local session file: %s", + log_prefix, + _unlink_err, + ) + + if cli_restore is not None: + result.transcript_content = stripped + transcript_builder.load_previous(stripped, log_prefix=log_prefix) + result.use_resume = True + result.resume_file = session_id + result.transcript_msg_count = cli_restore.message_count + return result + + # No valid --resume source (mode="baseline" or no GCS file). + # Build context from transcript content + gap, falling back to full DB. + # extract_context_messages handles both: non-None baseline_download uses + # the compacted transcript + gap; None falls back to all prior DB messages. 
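+    # Illustrative sketch of that dual behaviour (hypothetical helper names,
+    # not part of this change):
+    #
+    #   if baseline_download is not None:   # compacted transcript + gap
+    #       msgs = messages_from_jsonl(baseline_download.content)
+    #       msgs += db_gap(session.messages, baseline_download.message_count)
+    #   else:                               # full-DB fallback
+    #       msgs = session.messages[:-1]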
+    context_msgs = extract_context_messages(result.baseline_download, session.messages)
+    result.context_messages = context_msgs
+    result.transcript_msg_count = (
+        result.baseline_download.message_count
+        if result.baseline_download is not None
+        and result.baseline_download.message_count > 0
+        else len(session.messages) - 1
+    )
+    result.transcript_covers_prefix = True
+    logger.info(
+        "%s Context built from %s: %d messages (transcript watermark=%d, "
+        "will inject as conversation context)",
+        log_prefix,
+        (
+            "baseline transcript + gap"
+            if result.baseline_download is not None
+            else "DB fallback"
+        ),
+        len(context_msgs),
+        result.transcript_msg_count,
+    )
+
+    # Load baseline transcript content into builder so the upload path has accurate state.
+    # Also sets result.transcript_content so the _seed_transcript guard in the caller
+    # (``not transcript_content``) does not overwrite this builder state with a DB
+    # reconstruction — which would duplicate entries since load_previous appends.
+    if result.baseline_download is not None:
+        try:
+            raw_for_builder = result.baseline_download.content
+            if isinstance(raw_for_builder, bytes):
+                raw_for_builder = raw_for_builder.decode("utf-8")
+            stripped = strip_for_upload(raw_for_builder)
+            if validate_transcript(stripped):
+                transcript_builder.load_previous(stripped, log_prefix=log_prefix)
+                result.transcript_content = stripped
+        except (UnicodeDecodeError, ValueError, OSError) as _load_err:
+            # UnicodeDecodeError: non-UTF-8 content; ValueError: malformed JSONL in
+            # strip_for_upload; OSError: encode/decode I/O failure. Unexpected
+            # exceptions propagate so programming errors are not silently masked.
+            logger.debug(
+                "%s Could not load baseline transcript into builder: %s",
+                log_prefix,
+                _load_err,
+            )
+
+    return result
+
+
 async def stream_chat_completion_sdk(
     session_id: str,
     message: str | None = None,
@@ -2427,28 +2776,9 @@ async def stream_chat_completion_sdk(
 
         return sandbox
 
-    async def _fetch_transcript():
-        """Download transcript for --resume if applicable."""
-        if not (
-            config.claude_agent_use_resume and user_id and len(session.messages) > 1
-        ):
-            return None
-        try:
-            return await download_transcript(
-                user_id, session_id, log_prefix=log_prefix
-            )
-        except Exception as transcript_err:
-            logger.warning(
-                "%s Transcript download failed, continuing without --resume: %s",
-                log_prefix,
-                transcript_err,
-            )
-            return None
-
-    e2b_sandbox, (base_system_prompt, understanding), dl = await asyncio.gather(
+    e2b_sandbox, (base_system_prompt, understanding) = await asyncio.gather(
         _setup_e2b(),
         _build_system_prompt(user_id if not has_history else None),
-        _fetch_transcript(),
     )
 
     use_e2b = e2b_sandbox is not None
@@ -2473,95 +2803,17 @@ async def stream_chat_completion_sdk(
 
         warm_ctx = await fetch_warm_context(user_id, message or "") or ""
 
-        # Process transcript download result and restore CLI native session.
-        # The CLI native session file (uploaded after each turn) is the
-        # source of truth for --resume. Our custom JSONL (TranscriptEntry)
-        # is loaded into the builder for future upload_transcript calls.
-        transcript_msg_count = 0
-        if dl:
-            is_valid = validate_transcript(dl.content)
-            dl_lines = dl.content.strip().split("\n") if dl.content else []
-            logger.info(
-                "%s Downloaded transcript: %dB, %d lines, msg_count=%d, valid=%s",
-                log_prefix,
-                len(dl.content),
-                len(dl_lines),
-                dl.message_count,
-                is_valid,
-            )
-            if is_valid:
-                # Load previous FULL context into builder for state tracking.
- transcript_content = dl.content - transcript_builder.load_previous(dl.content, log_prefix=log_prefix) - # Restore CLI's native session file so --resume session_id works. - # Falls back gracefully if not available (first turn or upload missed). - # user_id is guaranteed non-None here: _fetch_transcript only sets dl - # when `config.claude_agent_use_resume and user_id` is truthy. - cli_restored = user_id is not None and await restore_cli_session( - user_id, session_id, sdk_cwd, log_prefix=log_prefix - ) - if cli_restored: - use_resume = True - resume_file = session_id # CLI --resume expects UUID, not file path - transcript_msg_count = dl.message_count - logger.info( - "%s Using --resume %s (%dB transcript, msg_count=%d)", - log_prefix, - session_id[:8], - len(dl.content), - transcript_msg_count, - ) - else: - # Builder loaded but CLI native session not available. - # --resume will not be used this turn; upload after turn - # will seed the native session for the next turn. - # - # Still record transcript_msg_count so _build_query_message - # can use the transcript-aware gap path (inject only new - # messages since the transcript end) instead of compressing - # the full DB history. This avoids prompt-too-long on - # large sessions where the CLI session is temporarily - # unavailable (e.g. mixed-version rolling deployment). - transcript_msg_count = dl.message_count - logger.info( - "%s CLI session not restored — running without" - " --resume this turn (transcript_msg_count=%d for" - " gap-aware fallback)", - log_prefix, - transcript_msg_count, - ) - else: - logger.warning("%s Transcript downloaded but invalid", log_prefix) - transcript_covers_prefix = False - elif config.claude_agent_use_resume and user_id and len(session.messages) > 1: - # No transcript in storage — reconstruct from DB messages as a - # last-resort fallback (e.g., first turn after a crash or transition). - # This path loses tool call IDs and structural fidelity but prevents - # a completely context-free response for established sessions. - prior = session.messages[:-1] - reconstructed = _session_messages_to_transcript(prior) - if reconstructed: - # Populate builder only; no --resume since there is no CLI - # native session to restore. The transcript builder state is - # still useful for the upload that seeds future native sessions. - transcript_content = reconstructed - transcript_builder.load_previous(reconstructed, log_prefix=log_prefix) - transcript_msg_count = len(prior) - transcript_covers_prefix = True - logger.info( - "%s Reconstructed transcript from %d session messages " - "(no CLI native session — running without --resume this turn)", - log_prefix, - len(prior), - ) - else: - logger.warning( - "%s No transcript available and reconstruction produced empty" - " output (%d messages in session)", - log_prefix, - len(session.messages), - ) - transcript_covers_prefix = False + # Restore CLI session — single GCS round-trip covers both --resume and builder state. + # message_count watermark lives in the companion .meta.json alongside the session file. 
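+        # For reference, the companion object written by upload_transcript is
+        # a tiny JSON document of the form (illustrative values):
+        #   {"message_count": 14, "mode": "sdk", "uploaded_at": 1718000000.0}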
+ _restore = await _restore_cli_session_for_turn( + user_id, session_id, session, sdk_cwd, transcript_builder, log_prefix + ) + transcript_content = _restore.transcript_content + transcript_covers_prefix = _restore.transcript_covers_prefix + use_resume = _restore.use_resume + resume_file = _restore.resume_file + transcript_msg_count = _restore.transcript_msg_count + restore_context_messages = _restore.context_messages yield StreamStart(messageId=message_id, sessionId=session_id) @@ -2680,14 +2932,14 @@ async def stream_chat_completion_sdk( else: # Set session_id whenever NOT resuming so the CLI writes the # native session file to a predictable path for - # upload_cli_session() after the turn. This covers: + # upload_transcript() after the turn. This covers: # • T1 fresh: no prior history, first SDK turn. # • Mode-switch T1: has_history=True (prior baseline turns in # DB) but no CLI session file was ever uploaded — the CLI has # never been invoked with this session_id before. # • T2+ without --resume (restore failed): no session file was - # restored to local storage (restore_cli_session returned - # False), so no conflict with an existing file. + # restored to local storage (download_transcript returned + # None), so no conflict with an existing file. # When --resume is active the session_id is already implied by # the resume file; passing it again would be rejected by the CLI. sdk_options_kwargs["session_id"] = session_id @@ -2780,6 +3032,7 @@ async def stream_chat_completion_sdk( use_resume, transcript_msg_count, session_id, + prior_messages=restore_context_messages, ) # If files are attached, prepare them: images become vision # content blocks in the user message, other files go to sdk_cwd. @@ -2909,7 +3162,7 @@ async def stream_chat_completion_sdk( elif "session_id" in sdk_options_kwargs: # Initial invocation used session_id (T1 or mode-switch # T1): keep it so the CLI writes the session file to the - # predictable path for upload_cli_session(). Storage is + # predictable path for upload_transcript(). Storage is # ephemeral per invocation, so no "Session ID already in # use" conflict occurs — no prior file was restored. sdk_options_kwargs_retry.pop("resume", None) @@ -2932,6 +3185,10 @@ async def stream_chat_completion_sdk( system_prompt, cross_user_cache=_cross_user_retry ) state.options = ClaudeAgentOptions(**sdk_options_kwargs_retry) # type: ignore[arg-type] # dynamic kwargs + # Retry intentionally omits prior_messages (transcript+gap context) and + # falls back to full session.messages[:-1] from DB — the authoritative + # source. transcript+gap is an optimisation for the first attempt only; + # on retry the extra overhead of full-DB context is acceptable. state.query_message, state.was_compacted = await _build_query_message( current_message, session, @@ -3367,86 +3624,23 @@ async def stream_chat_completion_sdk( _background_tasks.add(_ingest_task) _ingest_task.add_done_callback(_background_tasks.discard) - # --- Upload transcript for next-turn --resume --- - # TranscriptBuilder is the single source of truth. It mirrors the - # CLI's active context: on compaction, replace_entries() syncs it - # with the compacted session file. No CLI file read needed here. 
- if skip_transcript_upload: - logger.warning( - "%s Skipping transcript upload — transcript was dropped " - "during prompt-too-long recovery", - log_prefix, - ) - elif ( - config.claude_agent_use_resume - and user_id - and session is not None - and state is not None - ): - try: - transcript_upload_content = state.transcript_builder.to_jsonl() - entry_count = state.transcript_builder.entry_count - - if not transcript_upload_content: - logger.warning( - "%s No transcript to upload (builder empty)", log_prefix - ) - elif not validate_transcript(transcript_upload_content): - logger.warning( - "%s Transcript invalid, skipping upload (entries=%d)", - log_prefix, - entry_count, - ) - elif not transcript_covers_prefix: - logger.warning( - "%s Skipping transcript upload — builder does not " - "cover full session prefix (entries=%d, session=%d)", - log_prefix, - entry_count, - len(session.messages), - ) - else: - logger.info( - "%s Uploading transcript (entries=%d, bytes=%d)", - log_prefix, - entry_count, - len(transcript_upload_content), - ) - await asyncio.shield( - upload_transcript( - user_id=user_id, - session_id=session_id, - content=transcript_upload_content, - message_count=len(session.messages), - log_prefix=log_prefix, - ) - ) - except Exception as upload_err: - logger.error( - "%s Transcript upload failed in finally: %s", - log_prefix, - upload_err, - exc_info=True, - ) - # --- Upload CLI native session file for cross-pod --resume --- # The CLI writes its native session JSONL after each turn completes. - # Uploading it here enables --resume on any pod (no pod affinity needed). - # Runs after upload_transcript so both are available for the next turn. - # asyncio.shield: same pattern as upload_transcript above — if the - # outer finally-block coroutine is cancelled while awaiting shield, - # the CancelledError propagates (BaseException, not caught by - # `except Exception`) letting the caller handle cancellation, while - # the shielded inner coroutine continues running to completion so the - # upload is not lost. This is intentional and matches the pattern - # used for upload_transcript immediately above. + # The companion .meta.json carries the message_count watermark and mode + # so the next turn can restore both --resume context and gap-fill state + # in a single GCS round-trip via download_transcript(). + # asyncio.shield: if the outer finally-block coroutine is cancelled + # while awaiting shield, the CancelledError propagates (BaseException, + # not caught by `except Exception`) letting the caller handle + # cancellation, while the shielded inner coroutine continues running + # to completion so the upload is not lost. # # NOTE: upload is attempted regardless of state.use_resume — even when # this turn ran without --resume (restore failed or first T2+ on a new # pod), the T1 session file at the expected path may still be present # and should be re-uploaded so the next turn can resume from it. - # upload_cli_session silently skips when the file is absent, so this is - # always safe. + # read_cli_session_from_disk returns None when the file is absent, so + # this is always safe. 
            #
            # Intentionally NOT gated on skip_transcript_upload: that flag is set
            # when our custom JSONL transcript is dropped (transcript_lost=True on
@@ -3472,14 +3666,36 @@ async def stream_chat_completion_sdk(
                    skip_transcript_upload,
                )
                try:
-                    await asyncio.shield(
-                        upload_cli_session(
-                            user_id=user_id,
-                            session_id=session_id,
-                            sdk_cwd=sdk_cwd,
-                            log_prefix=log_prefix,
-                        )
+                    # Read the CLI's native session file from disk (written by the CLI
+                    # after the turn), then upload the bytes to GCS.
+                    _cli_content = read_cli_session_from_disk(
+                        sdk_cwd, session_id, log_prefix
                     )
+                    if _cli_content:
+                        # Watermark = number of DB messages this transcript covers.
+                        # len(session.messages) is accurate: the CLI session file
+                        # was just written after the turn completed, so it covers
+                        # all messages through this turn. Any gap from a prior
+                        # missed upload was already detected by detect_gap and
+                        # injected as context, so the model has the full history.
+                        #
+                        # Previously this used _final_tmsg_count + 2, which
+                        # under-counted for tool-use turns (delta = 2 + 2*N_tool_calls),
+                        # causing persistent spurious gap-fills on every subsequent turn.
+                        # That concern was addressed by the inflated-watermark fix
+                        # (using the GCS watermark as the anchor for gap detection),
+                        # which makes len(session.messages) safe to use here.
+                        _jsonl_covered = len(session.messages)
+                        await asyncio.shield(
+                            upload_transcript(
+                                user_id=user_id,
+                                session_id=session_id,
+                                content=_cli_content,
+                                message_count=_jsonl_covered,
+                                mode="sdk",
+                                log_prefix=log_prefix,
+                            )
+                        )
                except Exception as cli_upload_err:
                    logger.warning(
                        "%s CLI session upload failed in finally: %s",
diff --git a/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py b/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py
index 7c5e429697..3b919c6036 100644
--- a/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py
+++ b/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py
@@ -22,6 +22,7 @@ from .service import (
     _iter_sdk_messages,
     _normalize_model_name,
     _reduce_context,
+    _restore_cli_session_for_turn,
     _TokenUsage,
 )
 
@@ -615,3 +616,340 @@ class TestSdkSessionIdSelection:
         )
         assert retry.get("resume") == self.SESSION_ID
         assert "session_id" not in retry
+
+
+# ---------------------------------------------------------------------------
+# _restore_cli_session_for_turn — mode check
+# ---------------------------------------------------------------------------
+
+
+class TestRestoreCliSessionModeCheck:
+    """SDK skips --resume when the transcript was written by the baseline mode."""
+
+    @pytest.mark.asyncio
+    async def test_baseline_mode_transcript_skips_gcs_content(self, tmp_path):
+        """A transcript with mode='baseline' must not be used as the --resume source.
+
+        The mode check refuses --resume for baseline content but keeps it as
+        the transcript+gap context source, as the context_messages assertions
+        below verify.
+ """ + from datetime import UTC, datetime + + from backend.copilot.model import ChatMessage, ChatSession + from backend.copilot.transcript import TranscriptDownload + from backend.copilot.transcript_builder import TranscriptBuilder + + session = ChatSession( + session_id="test-session", + user_id="user-1", + messages=[ + ChatMessage(role="user", content="hello-unique-marker"), + ChatMessage(role="assistant", content="world-unique-marker"), + ChatMessage(role="user", content="follow up"), + ], + title="test", + usage=[], + started_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + builder = TranscriptBuilder() + # Baseline content with a sentinel that must NOT appear in the final transcript + baseline_restore = TranscriptDownload( + content=b'{"type":"user","uuid":"bad-uuid","message":{"role":"user","content":"BASELINE_SENTINEL"}}\n', + message_count=1, + mode="baseline", + ) + + import backend.copilot.sdk.service as _svc_mod + + download_mock = AsyncMock(return_value=baseline_restore) + with ( + patch( + "backend.copilot.sdk.service.download_transcript", + new=download_mock, + ), + patch.object(_svc_mod.config, "claude_agent_use_resume", True), + ): + result = await _restore_cli_session_for_turn( + user_id="user-1", + session_id="test-session", + session=session, + sdk_cwd=str(tmp_path), + transcript_builder=builder, + log_prefix="[Test]", + ) + + # download_transcript was called (attempted GCS restore) + download_mock.assert_awaited_once() + # use_resume must be False — baseline transcripts cannot be used with --resume + assert result.use_resume is False + # context_messages must be populated — new behaviour uses transcript content + gap + # instead of full DB reconstruction. + assert result.context_messages is not None + # The baseline transcript has 1 user message (BASELINE_SENTINEL). + # Watermark=1 but position 0 is 'user', not 'assistant', so detect_gap returns []. + # Result: 1 message from transcript, no gap. 
+ assert len(result.context_messages) == 1 + assert "BASELINE_SENTINEL" in (result.context_messages[0].content or "") + + @pytest.mark.asyncio + async def test_sdk_mode_transcript_allows_resume(self, tmp_path): + """A valid SDK-written transcript is accepted for --resume.""" + import json as stdlib_json + from datetime import UTC, datetime + + from backend.copilot.model import ChatMessage, ChatSession + from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload + from backend.copilot.transcript_builder import TranscriptBuilder + + lines = [ + stdlib_json.dumps( + { + "type": "user", + "uuid": "uid-0", + "parentUuid": "", + "message": {"role": "user", "content": "hi"}, + } + ), + stdlib_json.dumps( + { + "type": "assistant", + "uuid": "uid-1", + "parentUuid": "uid-0", + "message": { + "role": "assistant", + "id": "msg_1", + "model": "test", + "type": "message", + "stop_reason": STOP_REASON_END_TURN, + "content": [{"type": "text", "text": "hello"}], + }, + } + ), + ] + content = ("\n".join(lines) + "\n").encode("utf-8") + + session = ChatSession( + session_id="test-session", + user_id="user-1", + messages=[ + ChatMessage(role="user", content="hi"), + ChatMessage(role="assistant", content="hello"), + ChatMessage(role="user", content="follow up"), + ], + title="test", + usage=[], + started_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + builder = TranscriptBuilder() + sdk_restore = TranscriptDownload( + content=content, + message_count=2, + mode="sdk", + ) + + import backend.copilot.sdk.service as _svc_mod + + with ( + patch( + "backend.copilot.sdk.service.download_transcript", + new=AsyncMock(return_value=sdk_restore), + ), + patch.object(_svc_mod.config, "claude_agent_use_resume", True), + ): + result = await _restore_cli_session_for_turn( + user_id="user-1", + session_id="test-session", + session=session, + sdk_cwd=str(tmp_path), + transcript_builder=builder, + log_prefix="[Test]", + ) + + assert result.use_resume is True + + @pytest.mark.asyncio + async def test_baseline_mode_context_messages_from_transcript_content( + self, tmp_path + ): + """mode='baseline' → context_messages populated from transcript content + gap. + + When a baseline-mode transcript exists, extract_context_messages converts + the JSONL content to ChatMessage objects and returns them in context_messages. + use_resume must remain False. 
+ """ + import json as stdlib_json + from datetime import UTC, datetime + + from backend.copilot.model import ChatMessage, ChatSession + from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload + from backend.copilot.transcript_builder import TranscriptBuilder + + # Build a minimal valid JSONL transcript with 2 messages + lines = [ + stdlib_json.dumps( + { + "type": "user", + "uuid": "uid-0", + "parentUuid": "", + "message": {"role": "user", "content": "TRANSCRIPT_USER"}, + } + ), + stdlib_json.dumps( + { + "type": "assistant", + "uuid": "uid-1", + "parentUuid": "uid-0", + "message": { + "role": "assistant", + "id": "msg_1", + "model": "test", + "type": "message", + "stop_reason": STOP_REASON_END_TURN, + "content": [{"type": "text", "text": "TRANSCRIPT_ASSISTANT"}], + }, + } + ), + ] + content = ("\n".join(lines) + "\n").encode("utf-8") + + session = ChatSession( + session_id="test-session", + user_id="user-1", + messages=[ + ChatMessage(role="user", content="DB_USER"), + ChatMessage(role="assistant", content="DB_ASSISTANT"), + ChatMessage(role="user", content="current turn"), + ], + title="test", + usage=[], + started_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + builder = TranscriptBuilder() + baseline_restore = TranscriptDownload( + content=content, + message_count=2, + mode="baseline", + ) + + import backend.copilot.sdk.service as _svc_mod + + with ( + patch( + "backend.copilot.sdk.service.download_transcript", + new=AsyncMock(return_value=baseline_restore), + ), + patch.object(_svc_mod.config, "claude_agent_use_resume", True), + ): + result = await _restore_cli_session_for_turn( + user_id="user-1", + session_id="test-session", + session=session, + sdk_cwd=str(tmp_path), + transcript_builder=builder, + log_prefix="[Test]", + ) + + assert result.use_resume is False + assert result.context_messages is not None + # Transcript content has 2 messages, no gap (watermark=2, session prior=2) + assert len(result.context_messages) == 2 + assert result.context_messages[0].role == "user" + assert result.context_messages[1].role == "assistant" + assert "TRANSCRIPT_ASSISTANT" in (result.context_messages[1].content or "") + # transcript_content must be non-empty so the _seed_transcript guard in + # stream_chat_completion_sdk skips DB reconstruction (which would duplicate + # builder entries since load_previous appends). 
+ assert result.transcript_content != "" + + @pytest.mark.asyncio + async def test_baseline_mode_gap_present_context_includes_gap(self, tmp_path): + """mode='baseline' + gap → context_messages includes transcript msgs and gap.""" + import json as stdlib_json + from datetime import UTC, datetime + + from backend.copilot.model import ChatMessage, ChatSession + from backend.copilot.transcript import STOP_REASON_END_TURN, TranscriptDownload + from backend.copilot.transcript_builder import TranscriptBuilder + + # Transcript covers only 2 messages; session has 4 prior + current turn + lines = [ + stdlib_json.dumps( + { + "type": "user", + "uuid": "uid-0", + "parentUuid": "", + "message": {"role": "user", "content": "TRANSCRIPT_USER_0"}, + } + ), + stdlib_json.dumps( + { + "type": "assistant", + "uuid": "uid-1", + "parentUuid": "uid-0", + "message": { + "role": "assistant", + "id": "msg_1", + "model": "test", + "type": "message", + "stop_reason": STOP_REASON_END_TURN, + "content": [{"type": "text", "text": "TRANSCRIPT_ASSISTANT_1"}], + }, + } + ), + ] + content = ("\n".join(lines) + "\n").encode("utf-8") + + session = ChatSession( + session_id="test-session", + user_id="user-1", + messages=[ + ChatMessage(role="user", content="DB_USER_0"), + ChatMessage(role="assistant", content="DB_ASSISTANT_1"), + ChatMessage(role="user", content="GAP_USER_2"), + ChatMessage(role="assistant", content="GAP_ASSISTANT_3"), + ChatMessage(role="user", content="current turn"), + ], + title="test", + usage=[], + started_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + builder = TranscriptBuilder() + baseline_restore = TranscriptDownload( + content=content, + message_count=2, # watermark=2; session has 4 prior → gap of 2 + mode="baseline", + ) + + import backend.copilot.sdk.service as _svc_mod + + with ( + patch( + "backend.copilot.sdk.service.download_transcript", + new=AsyncMock(return_value=baseline_restore), + ), + patch.object(_svc_mod.config, "claude_agent_use_resume", True), + ): + result = await _restore_cli_session_for_turn( + user_id="user-1", + session_id="test-session", + session=session, + sdk_cwd=str(tmp_path), + transcript_builder=builder, + log_prefix="[Test]", + ) + + assert result.use_resume is False + assert result.context_messages is not None + # 2 from transcript + 2 gap messages = 4 total + assert len(result.context_messages) == 4 + roles = [m.role for m in result.context_messages] + assert roles == ["user", "assistant", "user", "assistant"] + # Gap messages come from DB (ChatMessage objects) + gap_user = result.context_messages[2] + gap_asst = result.context_messages[3] + assert gap_user.content == "GAP_USER_2" + assert gap_asst.content == "GAP_ASSISTANT_3" diff --git a/autogpt_platform/backend/backend/copilot/sdk/test_transcript_watermark.py b/autogpt_platform/backend/backend/copilot/sdk/test_transcript_watermark.py new file mode 100644 index 0000000000..592dbde82f --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/sdk/test_transcript_watermark.py @@ -0,0 +1,95 @@ +"""Unit tests for the watermark-fix logic in stream_chat_completion_sdk. + +The fix is at the upload step: when use_resume=True and transcript_msg_count>0 +we set the JSONL coverage watermark to transcript_msg_count + 2 (the pair just +recorded) instead of len(session.messages). This prevents the "inflated +watermark" bug where a stale JSONL in GCS could hide missing context from +future gap-fill checks. 
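+
+Worked example (mirrored by the first test below): the transcript covered 12
+messages but the DB has 47. The old code uploaded watermark=47, masking the
+gap; the fix uploads 12 + 2 = 14, so the next turn's check (14 < 47 - 1)
+fires and the missing turns are re-injected.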
+""" + +from __future__ import annotations + + +def _compute_jsonl_covered( + use_resume: bool, + transcript_msg_count: int, + session_msg_count: int, +) -> int: + """Mirror the watermark computation from ``stream_chat_completion_sdk``. + + Extracted here so we can unit-test it independently without invoking the + full streaming stack. + """ + if use_resume and transcript_msg_count > 0: + return transcript_msg_count + 2 + return session_msg_count + + +class TestWatermarkFix: + """Watermark computation logic — mirrors the finally-block in SDK service.""" + + def test_inflated_watermark_triggers_gap_fill(self): + """Stale JSONL (T12) with high watermark (46) → after fix, watermark=14. + + Before fix: watermark=46 → next turn's gap check (transcript_msg_count < db-1) + never fires because 46 >= 47-1=46, so context loss is silent. + After fix: watermark = 12 + 2 = 14 → gap check fires (14 < 46) and + the model receives the missing turns. + """ + # Simulate: use_resume=True, transcript covered T12 (12 msgs), DB now has 47 + use_resume = True + transcript_msg_count = 12 + session_msg_count = 47 # DB count (what old code used to set watermark) + + watermark = _compute_jsonl_covered( + use_resume, transcript_msg_count, session_msg_count + ) + + assert watermark == 14 # 12 + 2, NOT 47 + # Verify: the gap check would fire on next turn + # next-turn check: transcript_msg_count < msg_count - 1 → 14 < 47-1=46 → True + assert watermark < session_msg_count - 1 + + def test_no_false_positive_when_transcript_current(self): + """Transcript current (watermark=46, DB=47) → gap stays 0. + + When the JSONL actually covers T46 (the most recent assistant turn), + uploading watermark=46+2=48 means next turn's gap check sees + 48 >= 48-1=47 → no gap. Correct. + """ + use_resume = True + transcript_msg_count = 46 + session_msg_count = 47 + + watermark = _compute_jsonl_covered( + use_resume, transcript_msg_count, session_msg_count + ) + + assert watermark == 48 # 46 + 2 + # Next turn: session has 48 msgs, watermark=48 → 48 >= 48-1=47 → no gap + next_turn_session = 48 + assert watermark >= next_turn_session - 1 + + def test_fresh_session_falls_back_to_db_count(self): + """use_resume=False → watermark = len(session.messages) (original behaviour).""" + use_resume = False + transcript_msg_count = 0 + session_msg_count = 3 + + watermark = _compute_jsonl_covered( + use_resume, transcript_msg_count, session_msg_count + ) + + assert watermark == session_msg_count + + def test_old_format_meta_zero_count_falls_back_to_db(self): + """transcript_msg_count=0 (old-format meta with no count field) → DB fallback.""" + use_resume = True + transcript_msg_count = 0 # old-format meta or not-yet-set + session_msg_count = 10 + + watermark = _compute_jsonl_covered( + use_resume, transcript_msg_count, session_msg_count + ) + + assert watermark == session_msg_count diff --git a/autogpt_platform/backend/backend/copilot/sdk/transcript.py b/autogpt_platform/backend/backend/copilot/sdk/transcript.py index cfbf01a466..d5cf3c3e94 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/transcript.py +++ b/autogpt_platform/backend/backend/copilot/sdk/transcript.py @@ -12,18 +12,20 @@ from backend.copilot.transcript import ( ENTRY_TYPE_MESSAGE, STOP_REASON_END_TURN, STRIPPABLE_TYPES, - TRANSCRIPT_STORAGE_PREFIX, TranscriptDownload, + TranscriptMode, cleanup_stale_project_dirs, + cli_session_path, compact_transcript, delete_transcript, + detect_gap, download_transcript, + extract_context_messages, + projects_base, read_compacted_entries, - 
restore_cli_session, strip_for_upload, strip_progress_entries, strip_stale_thinking_blocks, - upload_cli_session, upload_transcript, validate_transcript, write_transcript_to_tempfile, @@ -34,18 +36,20 @@ __all__ = [ "ENTRY_TYPE_MESSAGE", "STOP_REASON_END_TURN", "STRIPPABLE_TYPES", - "TRANSCRIPT_STORAGE_PREFIX", "TranscriptDownload", + "TranscriptMode", "cleanup_stale_project_dirs", + "cli_session_path", "compact_transcript", "delete_transcript", + "detect_gap", "download_transcript", + "extract_context_messages", + "projects_base", "read_compacted_entries", - "restore_cli_session", "strip_for_upload", "strip_progress_entries", "strip_stale_thinking_blocks", - "upload_cli_session", "upload_transcript", "validate_transcript", "write_transcript_to_tempfile", diff --git a/autogpt_platform/backend/backend/copilot/sdk/transcript_test.py b/autogpt_platform/backend/backend/copilot/sdk/transcript_test.py index 14e404a994..01f3540c28 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/transcript_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/transcript_test.py @@ -297,8 +297,8 @@ class TestStripProgressEntries: class TestDeleteTranscript: @pytest.mark.asyncio - async def test_deletes_both_jsonl_and_meta(self): - """delete_transcript removes both the .jsonl and .meta.json files.""" + async def test_deletes_cli_session_and_meta(self): + """delete_transcript removes the CLI session .jsonl and .meta.json.""" mock_storage = AsyncMock() mock_storage.delete = AsyncMock() @@ -309,7 +309,7 @@ class TestDeleteTranscript: ): await delete_transcript("user-123", "session-456") - assert mock_storage.delete.call_count == 3 + assert mock_storage.delete.call_count == 2 paths = [call.args[0] for call in mock_storage.delete.call_args_list] assert any(p.endswith(".jsonl") for p in paths) assert any(p.endswith(".meta.json") for p in paths) @@ -319,7 +319,7 @@ class TestDeleteTranscript: """If .jsonl delete fails, .meta.json delete is still attempted.""" mock_storage = AsyncMock() mock_storage.delete = AsyncMock( - side_effect=[Exception("jsonl delete failed"), None, None] + side_effect=[Exception("jsonl delete failed"), None] ) with patch( @@ -330,14 +330,14 @@ class TestDeleteTranscript: # Should not raise await delete_transcript("user-123", "session-456") - assert mock_storage.delete.call_count == 3 + assert mock_storage.delete.call_count == 2 @pytest.mark.asyncio async def test_handles_meta_delete_failure(self): """If .meta.json delete fails, no exception propagates.""" mock_storage = AsyncMock() mock_storage.delete = AsyncMock( - side_effect=[None, Exception("meta delete failed"), None] + side_effect=[None, Exception("meta delete failed")] ) with patch( @@ -1015,7 +1015,7 @@ class TestCleanupStaleProjectDirs: projects_dir = tmp_path / "projects" projects_dir.mkdir() monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: str(projects_dir), ) @@ -1044,7 +1044,7 @@ class TestCleanupStaleProjectDirs: projects_dir = tmp_path / "projects" projects_dir.mkdir() monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: str(projects_dir), ) @@ -1070,7 +1070,7 @@ class TestCleanupStaleProjectDirs: projects_dir = tmp_path / "projects" projects_dir.mkdir() monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: str(projects_dir), ) @@ -1096,7 +1096,7 @@ class TestCleanupStaleProjectDirs: projects_dir = tmp_path / 
"projects" projects_dir.mkdir() monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: str(projects_dir), ) @@ -1118,7 +1118,7 @@ class TestCleanupStaleProjectDirs: nonexistent = str(tmp_path / "does-not-exist" / "projects") monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: nonexistent, ) @@ -1137,7 +1137,7 @@ class TestCleanupStaleProjectDirs: projects_dir = tmp_path / "projects" projects_dir.mkdir() monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: str(projects_dir), ) @@ -1165,7 +1165,7 @@ class TestCleanupStaleProjectDirs: projects_dir = tmp_path / "projects" projects_dir.mkdir() monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: str(projects_dir), ) @@ -1189,7 +1189,7 @@ class TestCleanupStaleProjectDirs: projects_dir = tmp_path / "projects" projects_dir.mkdir() monkeypatch.setattr( - "backend.copilot.transcript._projects_base", + "backend.copilot.transcript.projects_base", lambda: str(projects_dir), ) @@ -1368,3 +1368,172 @@ class TestStripStaleThinkingBlocks: # Both entries of last turn (msg_last) preserved assert lines[1]["message"]["content"][0]["type"] == "thinking" assert lines[2]["message"]["content"][0]["type"] == "text" + + +class TestProcessCliRestore: + """``process_cli_restore`` validates, strips, and writes CLI session to disk.""" + + def test_writes_stripped_bytes_not_raw(self, tmp_path): + """Stripped bytes (not raw bytes) must be written to disk for --resume.""" + import os + import re + from pathlib import Path + from unittest.mock import patch + + from backend.copilot.sdk.service import process_cli_restore + from backend.copilot.transcript import TranscriptDownload + + session_id = "12345678-0000-0000-0000-abcdef000001" + sdk_cwd = str(tmp_path) + projects_base_dir = str(tmp_path) + + # Build raw content with a strippable progress entry + a valid user/assistant pair + raw_content = ( + '{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n' + '{"type":"user","uuid":"u1","parentUuid":null,"message":{"role":"user","content":"hi"}}\n' + '{"type":"assistant","uuid":"a1","parentUuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello"}]}}\n' + ) + raw_bytes = raw_content.encode("utf-8") + restore = TranscriptDownload(content=raw_bytes, message_count=2, mode="sdk") + + with ( + patch( + "backend.copilot.sdk.service.projects_base", + return_value=projects_base_dir, + ), + patch( + "backend.copilot.transcript.projects_base", + return_value=projects_base_dir, + ), + ): + stripped_str, ok = process_cli_restore( + restore, sdk_cwd, session_id, "[Test]" + ) + + assert ok, "Expected successful restore" + + # Find the written session file + encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd)) + session_file = Path(projects_base_dir) / encoded_cwd / f"{session_id}.jsonl" + assert session_file.exists(), "Session file should have been written" + + written_bytes = session_file.read_bytes() + # The written bytes must be the stripped version (no progress entry) + assert ( + b"progress" not in written_bytes + ), "Raw bytes with progress entry should not have been written" + assert ( + b"hello" in written_bytes + ), "Stripped content should still contain assistant turn" + + # Written bytes must equal the stripped string re-encoded + assert 
written_bytes == stripped_str.encode( + "utf-8" + ), "Written bytes must equal stripped content" + + def test_invalid_content_returns_false(self): + """Content that fails validation after strip returns (empty, False).""" + from backend.copilot.sdk.service import process_cli_restore + from backend.copilot.transcript import TranscriptDownload + + # A single progress-only entry — stripped result will be empty/invalid + raw_content = '{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n' + restore = TranscriptDownload( + content=raw_content.encode("utf-8"), message_count=1, mode="sdk" + ) + + stripped_str, ok = process_cli_restore( + restore, + "/tmp/nonexistent-sdk-cwd", + "12345678-0000-0000-0000-000000000099", + "[Test]", + ) + + assert not ok + assert stripped_str == "" + + +class TestReadCliSessionFromDisk: + """``read_cli_session_from_disk`` reads, strips, and optionally writes back the session.""" + + def _build_session_file(self, tmp_path, session_id: str): + """Build the session file path inside tmp_path using the same encoding as cli_session_path.""" + import os + import re + from pathlib import Path + + sdk_cwd = str(tmp_path) + encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd)) + session_dir = Path(str(tmp_path)) / encoded_cwd + session_dir.mkdir(parents=True, exist_ok=True) + return sdk_cwd, session_dir / f"{session_id}.jsonl" + + def test_returns_raw_bytes_for_invalid_utf8(self, tmp_path): + """Non-UTF-8 bytes trigger UnicodeDecodeError — returns raw bytes (upload-raw fallback).""" + from unittest.mock import patch + + from backend.copilot.sdk.service import read_cli_session_from_disk + + session_id = "12345678-0000-0000-0000-aabbccdd0001" + projects_base_dir = str(tmp_path) + sdk_cwd, session_file = self._build_session_file(tmp_path, session_id) + + # Write raw invalid UTF-8 bytes + session_file.write_bytes(b"\xff\xfe invalid utf-8\n") + + with ( + patch( + "backend.copilot.sdk.service.projects_base", + return_value=projects_base_dir, + ), + patch( + "backend.copilot.transcript.projects_base", + return_value=projects_base_dir, + ), + ): + result = read_cli_session_from_disk(sdk_cwd, session_id, "[Test]") + + # UnicodeDecodeError path returns the raw bytes (upload-raw fallback) + assert result == b"\xff\xfe invalid utf-8\n" + + def test_write_back_oserror_still_returns_stripped_bytes(self, tmp_path): + """OSError on write-back returns stripped bytes for GCS upload (not raw).""" + from unittest.mock import patch + + from backend.copilot.sdk.service import read_cli_session_from_disk + + session_id = "12345678-0000-0000-0000-aabbccdd0002" + projects_base_dir = str(tmp_path) + sdk_cwd, session_file = self._build_session_file(tmp_path, session_id) + + # Content with a strippable progress entry so stripped_bytes < raw_bytes + raw_content = ( + '{"type":"progress","uuid":"p1","subtype":"agent_progress","parentUuid":null}\n' + '{"type":"user","uuid":"u1","parentUuid":null,"message":{"role":"user","content":"hi"}}\n' + '{"type":"assistant","uuid":"a1","parentUuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello"}]}}\n' + ) + session_file.write_bytes(raw_content.encode("utf-8")) + # Make the file read-only so write_bytes raises OSError on the write-back + session_file.chmod(0o444) + + try: + with ( + patch( + "backend.copilot.sdk.service.projects_base", + return_value=projects_base_dir, + ), + patch( + "backend.copilot.transcript.projects_base", + return_value=projects_base_dir, + ), + ): + result = 
read_cli_session_from_disk(sdk_cwd, session_id, "[Test]") + finally: + session_file.chmod(0o644) + + # Must return stripped bytes (not raw, not None) so GCS gets the clean version + assert result is not None + assert ( + b"progress" not in result + ), "Stripped bytes must not contain progress entry" + assert b"hello" in result, "Stripped bytes should contain assistant turn" diff --git a/autogpt_platform/backend/backend/copilot/service_test.py b/autogpt_platform/backend/backend/copilot/service_test.py index c4b1c3182e..ec9b13fb22 100644 --- a/autogpt_platform/backend/backend/copilot/service_test.py +++ b/autogpt_platform/backend/backend/copilot/service_test.py @@ -61,18 +61,23 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id): # (CLI version, platform). When that happens, multi-turn still works # via conversation compression (non-resume path), but we can't test # the --resume round-trip. - transcript = None + cli_session = None for _ in range(10): await asyncio.sleep(0.5) - transcript = await download_transcript(test_user_id, session.session_id) - if transcript: + cli_session = await download_transcript(test_user_id, session.session_id) + # Wait until both the session bytes AND the message_count watermark are + # present — a session with message_count=0 means the .meta.json hasn't + # been uploaded yet, so --resume on the next turn would skip gap-fill. + if cli_session and cli_session.message_count > 0: break - if not transcript: + if not cli_session: return pytest.skip( "CLI did not produce a usable transcript — " "cannot test --resume round-trip in this environment" ) - logger.info(f"Turn 1 transcript uploaded: {len(transcript.content)} bytes") + logger.info( + f"Turn 1 CLI session uploaded: {len(cli_session.content)} bytes, msg_count={cli_session.message_count}" + ) # Reload session for turn 2 session = await get_chat_session(session.session_id, test_user_id) diff --git a/autogpt_platform/backend/backend/copilot/stream_registry.py b/autogpt_platform/backend/backend/copilot/stream_registry.py index 030763dbca..02fa21b574 100644 --- a/autogpt_platform/backend/backend/copilot/stream_registry.py +++ b/autogpt_platform/backend/backend/copilot/stream_registry.py @@ -423,20 +423,33 @@ async def subscribe_to_session( extra={"json_fields": {**log_meta, "duration_ms": hgetall_time}}, ) - # RACE CONDITION FIX: If session not found, retry once after small delay - # This handles the case where subscribe_to_session is called immediately - # after create_session but before Redis propagates the write + # RACE CONDITION FIX: If session not found, retry with backoff. + # Duplicate requests skip create_session and subscribe immediately; the + # original request's create_session (a Redis hset) may not have completed + # yet. 3 × 100ms gives a 300ms window which covers DB-write latency on the + # original request before the hset even starts. 
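+    # (The retry loop below relies on Python's for/else: the else arm runs
+    # only when every attempt completes without a break, i.e. meta was never
+    # found within the retry window.)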
if not meta: - logger.warning( - "[TIMING] Session not found on first attempt, retrying after 50ms delay", - extra={"json_fields": {**log_meta}}, - ) - await asyncio.sleep(0.05) # 50ms - meta = await redis.hgetall(meta_key) # type: ignore[misc] - if not meta: + _max_retries = 3 + _retry_delay = 0.1 # 100ms per attempt + for attempt in range(_max_retries): + logger.warning( + f"[TIMING] Session not found (attempt {attempt + 1}/{_max_retries}), " + f"retrying after {int(_retry_delay * 1000)}ms", + extra={"json_fields": {**log_meta, "attempt": attempt + 1}}, + ) + await asyncio.sleep(_retry_delay) + meta = await redis.hgetall(meta_key) # type: ignore[misc] + if meta: + logger.info( + f"[TIMING] Session found after {attempt + 1} retries", + extra={"json_fields": {**log_meta, "attempts": attempt + 1}}, + ) + break + else: elapsed = (time.perf_counter() - start_time) * 1000 logger.info( - f"[TIMING] Session still not found in Redis after retry ({elapsed:.1f}ms total)", + f"[TIMING] Session still not found in Redis after {_max_retries} retries " + f"({elapsed:.1f}ms total)", extra={ "json_fields": { **log_meta, @@ -446,10 +459,6 @@ async def subscribe_to_session( }, ) return None - logger.info( - "[TIMING] Session found after retry", - extra={"json_fields": {**log_meta}}, - ) # Note: Redis client uses decode_responses=True, so keys are strings session_status = meta.get("status", "") diff --git a/autogpt_platform/backend/backend/copilot/transcript.py b/autogpt_platform/backend/backend/copilot/transcript.py index ea1bc2e81c..c4d3de28af 100644 --- a/autogpt_platform/backend/backend/copilot/transcript.py +++ b/autogpt_platform/backend/backend/copilot/transcript.py @@ -1,10 +1,10 @@ """JSONL transcript management for stateless multi-turn resume. The Claude Code CLI persists conversations as JSONL files (one JSON object per -line). When the SDK's ``Stop`` hook fires we read this file, strip bloat -(progress entries, metadata), and upload the result to bucket storage. On the -next turn we download the transcript, write it to a temp file, and pass -``--resume`` so the CLI can reconstruct the full conversation. +line). When the SDK's ``Stop`` hook fires the caller reads this file, strips +bloat (progress entries, metadata), and uploads the result to bucket storage. +On the next turn the caller downloads the bytes and writes them to disk before +passing ``--resume`` so the CLI can reconstruct the full conversation. Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local filesystem for self-hosted) — no DB column needed. @@ -20,6 +20,7 @@ import shutil import time from dataclasses import dataclass from pathlib import Path +from typing import TYPE_CHECKING, Literal from uuid import uuid4 from backend.util import json @@ -27,6 +28,9 @@ from backend.util.clients import get_openai_client from backend.util.prompt import CompressResult, compress_context from backend.util.workspace_storage import GCSWorkspaceStorage, get_workspace_storage +if TYPE_CHECKING: + from .model import ChatMessage + logger = logging.getLogger(__name__) # UUIDs are hex + hyphens; strip everything else to prevent path injection. 
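For orientation, the upload/restore round-trip the rewritten docstring describes, as a sketch using only names that appear in this diff (arguments abbreviated, error handling elided; illustrative, not part of the change):

    # after the turn completes (service finally-block)
    content = read_cli_session_from_disk(sdk_cwd, session_id, log_prefix)
    if content:
        await upload_transcript(
            user_id=user_id, session_id=session_id, content=content,
            message_count=len(session.messages), mode="sdk", log_prefix=log_prefix,
        )

    # next turn, possibly on a different pod
    dl = await download_transcript(user_id, session_id, log_prefix=log_prefix)
    if dl is not None and dl.mode == "sdk":
        stripped, ok = process_cli_restore(dl, sdk_cwd, session_id, log_prefix)
        use_resume = ok  # --resume <session_id> only if the file landed on disk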
@@ -44,17 +48,17 @@ STRIPPABLE_TYPES = frozenset( ) +TranscriptMode = Literal["sdk", "baseline"] + + @dataclass class TranscriptDownload: - """Result of downloading a transcript with its metadata.""" - - content: str - message_count: int = 0 # session.messages length when uploaded - uploaded_at: float = 0.0 # epoch timestamp of upload + content: bytes | str + message_count: int = 0 + # "sdk" = Claude CLI native, "baseline" = TranscriptBuilder + mode: TranscriptMode = "sdk" -# Workspace storage constants — deterministic path from session_id. -TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts" # Storage prefix for the CLI's native session JSONL files (for cross-pod --resume). _CLI_SESSION_STORAGE_PREFIX = "cli-sessions" @@ -363,7 +367,7 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str: _SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-") -def _projects_base() -> str: +def projects_base() -> str: """Return the resolved path to the CLI's projects directory.""" config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude") return os.path.realpath(os.path.join(config_dir, "projects")) @@ -390,8 +394,8 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int: Returns the number of directories removed. """ - projects_base = _projects_base() - if not os.path.isdir(projects_base): + _pbase = projects_base() + if not os.path.isdir(_pbase): return 0 now = time.time() @@ -399,7 +403,7 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int: # Scoped mode: only clean up the one directory for the current session. if encoded_cwd: - target = Path(projects_base) / encoded_cwd + target = Path(_pbase) / encoded_cwd if not target.is_dir(): return 0 # Guard: only sweep copilot-generated dirs. @@ -437,7 +441,7 @@ def cleanup_stale_project_dirs(encoded_cwd: str | None = None) -> int: # Only safe for single-tenant deployments; callers should prefer the # scoped variant by passing encoded_cwd. try: - entries = Path(projects_base).iterdir() + entries = Path(_pbase).iterdir() except OSError as e: logger.warning("[Transcript] Failed to list projects dir: %s", e) return 0 @@ -490,9 +494,9 @@ def read_compacted_entries(transcript_path: str) -> list[dict] | None: if not transcript_path: return None - projects_base = _projects_base() + _pbase = projects_base() real_path = os.path.realpath(transcript_path) - if not real_path.startswith(projects_base + os.sep): + if not real_path.startswith(_pbase + os.sep): logger.warning( "[Transcript] transcript_path outside projects base: %s", transcript_path ) @@ -611,28 +615,6 @@ def validate_transcript(content: str | None) -> bool: # --------------------------------------------------------------------------- -def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]: - """Return (workspace_id, file_id, filename) for a session's transcript. - - Path structure: ``chat-transcripts/{user_id}/{session_id}.jsonl`` - IDs are sanitized to hex+hyphen to prevent path traversal. 
- """ - return ( - TRANSCRIPT_STORAGE_PREFIX, - _sanitize_id(user_id), - f"{_sanitize_id(session_id)}.jsonl", - ) - - -def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]: - """Return (workspace_id, file_id, filename) for a session's transcript metadata.""" - return ( - TRANSCRIPT_STORAGE_PREFIX, - _sanitize_id(user_id), - f"{_sanitize_id(session_id)}.meta.json", - ) - - def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str: """Build a full storage path from (workspace_id, file_id, filename) parts.""" wid, fid, fname = parts @@ -642,24 +624,12 @@ def _build_path_from_parts(parts: tuple[str, str, str], backend: object) -> str: return f"local://{wid}/{fid}/{fname}" -def _build_storage_path(user_id: str, session_id: str, backend: object) -> str: - """Build the full storage path string that ``retrieve()`` expects.""" - return _build_path_from_parts(_storage_path_parts(user_id, session_id), backend) - - -def _build_meta_storage_path(user_id: str, session_id: str, backend: object) -> str: - """Build the full storage path for the companion .meta.json file.""" - return _build_path_from_parts( - _meta_storage_path_parts(user_id, session_id), backend - ) - - # --------------------------------------------------------------------------- # CLI native session file — cross-pod --resume support # --------------------------------------------------------------------------- -def _cli_session_path(sdk_cwd: str, session_id: str) -> str: +def cli_session_path(sdk_cwd: str, session_id: str) -> str: """Expected path of the CLI's native session JSONL file. The CLI resolves the working directory via ``os.path.realpath``, then @@ -675,7 +645,7 @@ def _cli_session_path(sdk_cwd: str, session_id: str) -> str: """ encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd)) safe_id = _sanitize_id(session_id) - return os.path.join(_projects_base(), encoded_cwd, f"{safe_id}.jsonl") + return os.path.join(projects_base(), encoded_cwd, f"{safe_id}.jsonl") def _cli_session_storage_path_parts( @@ -689,235 +659,82 @@ def _cli_session_storage_path_parts( ) -async def upload_cli_session( - user_id: str, - session_id: str, - sdk_cwd: str, - log_prefix: str = "[Transcript]", -) -> None: - """Upload the CLI's native session JSONL file to remote storage. - - Called after each turn so the next turn can restore the file on any pod - (eliminating the pod-affinity requirement for --resume). - - The CLI only writes the session file after the turn completes, so this - must run in the finally block, AFTER the SDK stream has finished. - """ - session_file = _cli_session_path(sdk_cwd, session_id) - real_path = os.path.realpath(session_file) - projects_base = _projects_base() - - if not real_path.startswith(projects_base + os.sep): - logger.warning( - "%s CLI session file outside projects base, skipping upload: %s", - log_prefix, - os.path.basename(real_path), - ) - return - - try: - raw_bytes = Path(real_path).read_bytes() - except FileNotFoundError: - logger.debug( - "%s CLI session file not found, skipping upload: %s", - log_prefix, - session_file, - ) - return - except OSError as e: - logger.warning("%s Failed to read CLI session file: %s", log_prefix, e) - return - - # Strip stale thinking blocks and metadata entries (progress, file-history-snapshot, - # queue-operation) from the CLI session before writing it back locally and uploading - # to GCS. 
Thinking blocks from non-last assistant turns are not needed for --resume - # but can be massive (tens of thousands of tokens each), causing the CLI to auto-compact - # its session when the context window fills up. Stripping keeps the session well below - # the ~200K-token compaction threshold and prevents silent context loss. - try: - raw_text = raw_bytes.decode("utf-8") - stripped_text = strip_for_upload(raw_text) - stripped_bytes = stripped_text.encode("utf-8") - if len(stripped_bytes) < len(raw_bytes): - # Write the stripped version back locally so same-pod turns also benefit. - Path(real_path).write_bytes(stripped_bytes) - logger.info( - "%s Stripped CLI session file: %dB → %dB", - log_prefix, - len(raw_bytes), - len(stripped_bytes), - ) - content = stripped_bytes - except Exception as e: - logger.warning( - "%s Failed to strip CLI session file, uploading raw: %s", log_prefix, e - ) - content = raw_bytes - - storage = await get_workspace_storage() - wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id) - try: - await storage.store( - workspace_id=wid, file_id=fid, filename=fname, content=content - ) - logger.info( - "%s Uploaded CLI session file (%dB) for cross-pod --resume", - log_prefix, - len(content), - ) - except Exception as e: - logger.warning("%s Failed to upload CLI session file: %s", log_prefix, e) - - -async def restore_cli_session( - user_id: str, - session_id: str, - sdk_cwd: str, - log_prefix: str = "[Transcript]", -) -> bool: - """Download and restore the CLI's native session file for --resume. - - Returns True if the file was successfully restored and --resume can be - used with the session UUID. Returns False if not available (first turn - or upload failed), in which case the caller should not set --resume. - """ - session_file = _cli_session_path(sdk_cwd, session_id) - real_path = os.path.realpath(session_file) - projects_base = _projects_base() - - if not real_path.startswith(projects_base + os.sep): - logger.warning( - "%s CLI session restore path outside projects base: %s", - log_prefix, - os.path.basename(session_file), - ) - return False - - # If the session file already exists locally (same-pod reuse), use it directly. - # Downloading from storage could overwrite a newer local version when a previous - # turn's upload failed: stored content is stale while the local file already - # contains extended history from that turn. 
- if Path(real_path).exists(): - logger.debug( - "%s CLI session file already exists locally — using it for --resume", - log_prefix, - ) - return True - - storage = await get_workspace_storage() - path = _build_path_from_parts( - _cli_session_storage_path_parts(user_id, session_id), storage +def _cli_session_meta_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]: + """Return (workspace_id, file_id, filename) for the CLI session meta file.""" + return ( + _CLI_SESSION_STORAGE_PREFIX, + _sanitize_id(user_id), + f"{_sanitize_id(session_id)}.meta.json", ) - try: - content = await storage.retrieve(path) - except FileNotFoundError: - logger.debug("%s No CLI session in storage (first turn or missing)", log_prefix) - return False - except Exception as e: - logger.warning("%s Failed to download CLI session: %s", log_prefix, e) - return False - - try: - os.makedirs(os.path.dirname(real_path), exist_ok=True) - Path(real_path).write_bytes(content) - logger.info( - "%s Restored CLI session file (%dB) for --resume", - log_prefix, - len(content), - ) - return True - except OSError as e: - logger.warning("%s Failed to write CLI session file: %s", log_prefix, e) - return False - async def upload_transcript( user_id: str, session_id: str, - content: str, + content: bytes, message_count: int = 0, + mode: TranscriptMode = "sdk", log_prefix: str = "[Transcript]", - skip_strip: bool = False, ) -> None: - """Strip progress entries and stale thinking blocks, then upload transcript. + """Upload CLI session content to GCS with companion meta.json. - The transcript represents the FULL active context (atomic). - Each upload REPLACES the previous transcript entirely. + Pure GCS operation — no disk I/O. The caller is responsible for reading + the session file from disk before calling this function. - The executor holds a cluster lock per session, so concurrent uploads for - the same session cannot happen. + Also uploads a companion .meta.json with the message_count watermark so + download_transcript can return it without a separate fetch. - Args: - content: Complete JSONL transcript (from TranscriptBuilder). - message_count: ``len(session.messages)`` at upload time. - skip_strip: When ``True``, skip the strip + re-validate pass. - Safe for builder-generated content (baseline path) which - never emits progress entries or stale thinking blocks. + Called after each turn so the next turn can restore the file on any pod + (eliminating the pod-affinity requirement for --resume). """ - if skip_strip: - # Caller guarantees the content is already clean and valid. - stripped = content - else: - # Strip metadata entries and stale thinking blocks in a single parse. - # SDK-built transcripts may have progress entries; strip for safety. 
-        stripped = strip_for_upload(content)
-    if not skip_strip and not validate_transcript(stripped):
-        # Log entry types for debugging — helps identify why validation failed
-        entry_types = [
-            json.loads(line, fallback={"type": "INVALID_JSON"}).get("type", "?")
-            for line in stripped.strip().split("\n")
-        ]
-        logger.warning(
-            "%s Skipping upload — stripped content not valid "
-            "(types=%s, stripped_len=%d, raw_len=%d)",
-            log_prefix,
-            entry_types,
-            len(stripped),
-            len(content),
-        )
-        logger.debug("%s Raw content preview: %s", log_prefix, content[:500])
-        logger.debug("%s Stripped content: %s", log_prefix, stripped[:500])
-        return
-
     storage = await get_workspace_storage()
-    wid, fid, fname = _storage_path_parts(user_id, session_id)
-    encoded = stripped.encode("utf-8")
-    meta = {"message_count": message_count, "uploaded_at": time.time()}
-    mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
+    wid, fid, fname = _cli_session_storage_path_parts(user_id, session_id)
+    mwid, mfid, mfname = _cli_session_meta_path_parts(user_id, session_id)
+    meta = {"message_count": message_count, "mode": mode, "uploaded_at": time.time()}
     meta_encoded = json.dumps(meta).encode("utf-8")

-    # Transcript + metadata are independent objects at different keys, so
-    # write them concurrently. ``return_exceptions`` keeps a metadata
-    # failure from sinking the transcript write.
-    transcript_result, metadata_result = await asyncio.gather(
-        storage.store(
-            workspace_id=wid,
-            file_id=fid,
-            filename=fname,
-            content=encoded,
-        ),
-        storage.store(
-            workspace_id=mwid,
-            file_id=mfid,
-            filename=mfname,
-            content=meta_encoded,
-        ),
-        return_exceptions=True,
-    )
-    if isinstance(transcript_result, BaseException):
-        raise transcript_result
-    if isinstance(metadata_result, BaseException):
-        # Metadata is best-effort — the gap-fill logic in
-        # _build_query_message tolerates a missing metadata file.
-        logger.warning("%s Failed to write metadata: %s", log_prefix, metadata_result)
+    # Write JSONL first, meta second — sequential so a crash between the two
+    # leaves an orphaned JSONL (no meta) rather than an orphaned meta (wrong
+    # watermark / mode paired with stale or absent content).
+    # If the meta write fails we roll back the JSONL so the pair is always
+    # absent together; download_transcript returns None only when the JSONL is
+    # missing (a missing meta alone falls back to default mode/watermark).
+    try:
+        await storage.store(
+            workspace_id=wid, file_id=fid, filename=fname, content=content
+        )
+    except Exception as session_err:
+        logger.warning(
+            "%s Failed to upload CLI session file: %s", log_prefix, session_err
+        )
+        return
+
+    try:
+        await storage.store(
+            workspace_id=mwid, file_id=mfid, filename=mfname, content=meta_encoded
+        )
+    except Exception as meta_err:
+        logger.warning("%s Failed to upload CLI session meta: %s", log_prefix, meta_err)
+        # Roll back the JSONL so neither file exists — avoids orphaned JSONL being
+        # used with wrong mode/watermark defaults on the next restore.
+        try:
+            session_path = _build_path_from_parts(
+                _cli_session_storage_path_parts(user_id, session_id), storage
+            )
+            await storage.delete(session_path)
+        except Exception as rollback_err:
+            logger.debug(
+                "%s Session rollback failed (orphaned JSONL may be restored with"
+                " default mode/watermark on the next turn): %s",
+                log_prefix,
+                rollback_err,
+            )
+        return

     logger.info(
-        "%s Uploaded %dB (stripped from %dB, msg_count=%d)",
+        "%s Uploaded CLI session (%dB, msg_count=%d, mode=%s)",
         log_prefix,
-        len(encoded),
         len(content),
         message_count,
+        mode,
    )


@@ -926,83 +743,173 @@ async def download_transcript(
     session_id: str,
     log_prefix: str = "[Transcript]",
 ) -> TranscriptDownload | None:
-    """Download transcript and metadata from bucket storage.
+    """Download CLI session from GCS. Returns content + message_count + mode, or None if not found.

-    Returns a ``TranscriptDownload`` with the JSONL content and the
-    ``message_count`` watermark from the upload, or ``None`` if not found.
+    Pure GCS operation — no disk I/O. The caller is responsible for writing
+    content to disk if --resume is needed.

-    The content and metadata fetches run concurrently since they are
-    independent objects in the bucket.
+    Returns a TranscriptDownload with the raw content, message_count watermark,
+    and mode on success, or None if not available (first turn or upload failed).
     """
     storage = await get_workspace_storage()
-    path = _build_storage_path(user_id, session_id, storage)
-    meta_path = _build_meta_storage_path(user_id, session_id, storage)
+    path = _build_path_from_parts(
+        _cli_session_storage_path_parts(user_id, session_id), storage
+    )
+    meta_path = _build_path_from_parts(
+        _cli_session_meta_path_parts(user_id, session_id), storage
+    )

-    content_task = asyncio.create_task(storage.retrieve(path))
-    meta_task = asyncio.create_task(storage.retrieve(meta_path))
     content_result, meta_result = await asyncio.gather(
-        content_task, meta_task, return_exceptions=True
+        storage.retrieve(path),
+        storage.retrieve(meta_path),
+        return_exceptions=True,
     )

     if isinstance(content_result, FileNotFoundError):
-        logger.debug("%s No transcript in storage", log_prefix)
+        logger.debug("%s No CLI session in storage (first turn or missing)", log_prefix)
         return None
     if isinstance(content_result, BaseException):
         logger.warning(
-            "%s Failed to download transcript: %s", log_prefix, content_result
+            "%s Failed to download CLI session: %s", log_prefix, content_result
         )
         return None
-    content = content_result.decode("utf-8")
+    content: bytes = content_result

-    # Metadata is best-effort — old transcripts won't have it.
+    # Parse message_count and mode from companion meta — best-effort, defaults.
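+    # For illustration, a companion meta payload as written by upload_transcript
+    # looks like {"message_count": 7, "mode": "sdk", "uploaded_at": 1234567.0};
+    # only message_count and mode are read back here (uploaded_at is ignored).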
     message_count = 0
-    uploaded_at = 0.0
+    mode: TranscriptMode = "sdk"
     if isinstance(meta_result, FileNotFoundError):
-        pass  # No metadata — treat as unknown (msg_count=0 → always fill gap)
+        pass  # No meta — old upload; default to "sdk"
     elif isinstance(meta_result, BaseException):
-        logger.debug(
-            "%s Failed to load transcript metadata: %s", log_prefix, meta_result
-        )
+        logger.debug("%s Failed to load CLI session meta: %s", log_prefix, meta_result)
     else:
-        meta = json.loads(meta_result.decode("utf-8"), fallback={})
-        message_count = meta.get("message_count", 0)
-        uploaded_at = meta.get("uploaded_at", 0.0)
+        try:
+            meta_str = meta_result.decode("utf-8")
+        except UnicodeDecodeError:
+            logger.debug("%s CLI session meta is not valid UTF-8, ignoring", log_prefix)
+            meta_str = None
+        if meta_str is not None:
+            meta = json.loads(meta_str, fallback={})
+            if isinstance(meta, dict):
+                raw_count = meta.get("message_count", 0)
+                message_count = (
+                    raw_count if isinstance(raw_count, int) and raw_count >= 0 else 0
+                )
+                raw_mode = meta.get("mode", "sdk")
+                mode = raw_mode if raw_mode in ("sdk", "baseline") else "sdk"

     logger.info(
-        "%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
-    )
-    return TranscriptDownload(
-        content=content,
-        message_count=message_count,
-        uploaded_at=uploaded_at,
+        "%s Downloaded CLI session (%dB, msg_count=%d, mode=%s)",
+        log_prefix,
+        len(content),
+        message_count,
+        mode,
     )
+    return TranscriptDownload(content=content, message_count=message_count, mode=mode)
+
+
+def detect_gap(
+    download: TranscriptDownload,
+    session_messages: list[ChatMessage],
+) -> list[ChatMessage]:
+    """Return chat-db messages after the transcript watermark (excluding current user turn).
+
+    Returns [] if transcript is current, watermark is zero, or the watermark
+    position doesn't end on an assistant turn (misaligned watermark).
+    """
+    if download.message_count == 0:
+        return []
+    wm = download.message_count
+    total = len(session_messages)
+    if wm >= total - 1:
+        return []
+    # Sanity: position wm-1 should be an assistant turn; misaligned watermark
+    # means the DB messages shifted (e.g. deletion) — skip gap to avoid wrong context.
+    # In normal operation ``message_count`` is always written after a complete
+    # user→assistant exchange (never mid-turn), so the last covered position is
+    # always assistant. This guard fires only on data corruption or message deletion.
+    if session_messages[wm - 1].role != "assistant":
+        return []
+    return list(session_messages[wm : total - 1])
+
+
+def extract_context_messages(
+    download: TranscriptDownload | None,
+    session_messages: "list[ChatMessage]",
+) -> "list[ChatMessage]":
+    """Return context messages for the current turn: transcript content + gap.
+
+    This is the shared context primitive used by both the SDK path
+    (``use_resume=False`` → conversation-history injection) and the
+    baseline path (OpenAI messages array).
+
+    How it works:
+
+    - When a transcript exists, ``TranscriptBuilder.load_previous`` preserves
+      ``isCompactSummary=True`` compaction entries, so the returned messages
+      mirror the compacted context the CLI would see via ``--resume``.
+    - The gap (DB messages after the transcript watermark) is always small in
+      normal operation; it only grows during mode switches or when an upload
+      was missed.
+    - Falls back to full DB messages when no transcript exists (first turn,
+      upload failure, or GCS unavailable).
+    - Returns *prior* messages only (excluding the current user turn at
+      ``session_messages[-1]``).
Callers that need the current turn append + ``session_messages[-1]`` themselves. + - **Tool calls from transcript entries are flattened to text**: assistant + messages derived from the JSONL use ``_flatten_assistant_content``, which + serialises ``tool_use`` blocks as human-readable text rather than + structured ``tool_calls``. Gap messages (from DB) preserve their + original ``tool_calls`` field. This is the same trade-off as the old + ``_compress_session_messages(session.messages)`` approach — no regression. + + Args: + download: The ``TranscriptDownload`` from GCS, or ``None`` when no + transcript is available. ``content`` may be either ``bytes`` or + ``str`` (the baseline path decodes + strips before returning). + session_messages: All messages in the session, with the current user + turn as the last element. + + Returns: + A list of ``ChatMessage`` objects covering the prior conversation + context, suitable for injection as conversation history. + """ + from .model import ChatMessage as _ChatMessage # runtime import + + prior = session_messages[:-1] + + if download is None: + return prior + + raw_content = download.content + if not raw_content: + return prior + + # Handle both bytes (raw GCS download) and str (pre-decoded baseline path). + if isinstance(raw_content, bytes): + try: + content_str: str = raw_content.decode("utf-8") + except UnicodeDecodeError: + return prior + else: + content_str = raw_content + + raw = _transcript_to_messages(content_str) + if not raw: + return prior + + transcript_msgs = [ + _ChatMessage(role=m["role"], content=m.get("content") or "") for m in raw + ] + gap = detect_gap(download, session_messages) + return transcript_msgs + gap async def delete_transcript(user_id: str, session_id: str) -> None: - """Delete transcript and its metadata from bucket storage. - - Removes both the ``.jsonl`` transcript and the companion ``.meta.json`` - so stale ``message_count`` watermarks cannot corrupt gap-fill logic. - """ + """Delete CLI session JSONL and its companion .meta.json from bucket storage.""" storage = await get_workspace_storage() - path = _build_storage_path(user_id, session_id, storage) - try: - await storage.delete(path) - logger.info("[Transcript] Deleted transcript for session %s", session_id) - except Exception as e: - logger.warning("[Transcript] Failed to delete transcript: %s", e) - - # Also delete the companion .meta.json to avoid orphaned metadata. - try: - meta_path = _build_meta_storage_path(user_id, session_id, storage) - await storage.delete(meta_path) - logger.info("[Transcript] Deleted metadata for session %s", session_id) - except Exception as e: - logger.warning("[Transcript] Failed to delete metadata: %s", e) - - # Also delete the CLI native session file to prevent storage growth. 
try: cli_path = _build_path_from_parts( _cli_session_storage_path_parts(user_id, session_id), storage @@ -1012,6 +919,15 @@ async def delete_transcript(user_id: str, session_id: str) -> None: except Exception as e: logger.warning("[Transcript] Failed to delete CLI session: %s", e) + try: + cli_meta_path = _build_path_from_parts( + _cli_session_meta_path_parts(user_id, session_id), storage + ) + await storage.delete(cli_meta_path) + logger.info("[Transcript] Deleted CLI session meta for session %s", session_id) + except Exception as e: + logger.warning("[Transcript] Failed to delete CLI session meta: %s", e) + # --------------------------------------------------------------------------- # Transcript compaction — LLM summarization for prompt-too-long recovery diff --git a/autogpt_platform/backend/backend/copilot/transcript_test.py b/autogpt_platform/backend/backend/copilot/transcript_test.py index 88be88b07a..dde07a063e 100644 --- a/autogpt_platform/backend/backend/copilot/transcript_test.py +++ b/autogpt_platform/backend/backend/copilot/transcript_test.py @@ -16,11 +16,11 @@ from .transcript import ( _flatten_assistant_content, _flatten_tool_result_content, _messages_to_transcript, - _meta_storage_path_parts, _rechain_tail, _sanitize_id, - _storage_path_parts, _transcript_to_messages, + detect_gap, + extract_context_messages, strip_for_upload, validate_transcript, ) @@ -64,24 +64,6 @@ class TestSanitizeId: assert _sanitize_id("!@#$%^&*()") == "unknown" -# --------------------------------------------------------------------------- -# _storage_path_parts / _meta_storage_path_parts -# --------------------------------------------------------------------------- - - -class TestStoragePathParts: - def test_returns_triple(self): - prefix, uid, fname = _storage_path_parts("user-1", "sess-2") - assert prefix == "chat-transcripts" - assert "e" in uid # hex chars from "user-1" sanitized - assert fname.endswith(".jsonl") - - def test_meta_returns_meta_json(self): - prefix, _, fname = _meta_storage_path_parts("user-1", "sess-2") - assert prefix == "chat-transcripts" - assert fname.endswith(".meta.json") - - # --------------------------------------------------------------------------- # _build_path_from_parts # --------------------------------------------------------------------------- @@ -103,24 +85,6 @@ class TestBuildPathFromParts: assert path == "local://wid/fid/file.jsonl" -# --------------------------------------------------------------------------- -# TranscriptDownload dataclass -# --------------------------------------------------------------------------- - - -class TestTranscriptDownload: - def test_defaults(self): - td = TranscriptDownload(content="hello") - assert td.content == "hello" - assert td.message_count == 0 - assert td.uploaded_at == 0.0 - - def test_custom_values(self): - td = TranscriptDownload(content="data", message_count=5, uploaded_at=123.45) - assert td.message_count == 5 - assert td.uploaded_at == 123.45 - - # --------------------------------------------------------------------------- # _flatten_assistant_content # --------------------------------------------------------------------------- @@ -733,215 +697,194 @@ class TestValidateTranscript: class TestCliSessionPath: def test_encodes_slashes_to_dashes(self): - from .transcript import _cli_session_path, _projects_base + from .transcript import cli_session_path, projects_base sdk_cwd = "/tmp/copilot-abc" - result = _cli_session_path(sdk_cwd, "12345678-1234-1234-1234-123456789abc") - base = _projects_base() + result = 
cli_session_path(sdk_cwd, "12345678-1234-1234-1234-123456789abc") + base = projects_base() assert result.startswith(base) # Encoded cwd replaces '/' with '-' assert "-tmp-copilot-abc" in result assert result.endswith(".jsonl") def test_sanitizes_session_id(self): - from .transcript import _cli_session_path + from .transcript import cli_session_path - result = _cli_session_path("/tmp/cwd", "../../etc/passwd") + result = cli_session_path("/tmp/cwd", "../../etc/passwd") # _sanitize_id strips non-hex/hyphen chars; path traversal impossible assert ".." not in result assert "passwd" not in result class TestUploadCliSession: - def test_skips_upload_when_path_outside_projects_base(self, tmp_path): - """Files outside the CLI projects base are rejected without upload.""" + def test_uploads_content_bytes_successfully(self): + """Happy path: content bytes are stored as jsonl + meta.json.""" import asyncio from unittest.mock import AsyncMock, patch - from .transcript import upload_cli_session + from .transcript import upload_transcript mock_storage = AsyncMock() + content = b'{"type":"assistant"}\n' - with ( - patch( - "backend.copilot.transcript._projects_base", - return_value=str(tmp_path), - ), - # Return a path that is genuinely outside tmp_path so that - # realpath(session_file).startswith(projects_base + "/") is False - # and the boundary guard actually fires. - patch( - "backend.copilot.transcript._cli_session_path", - return_value="/outside/escaped/session.jsonl", - ), - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, ): asyncio.run( - upload_cli_session( + upload_transcript( user_id="user-1", - session_id="12345678-0000-0000-0000-000000000000", - sdk_cwd=str(tmp_path), + session_id="12345678-0000-0000-0000-000000000001", + content=content, ) ) - # storage.store must NOT be called — boundary guard should reject the path - mock_storage.store.assert_not_called() + # Two calls expected: session JSONL + companion .meta.json + assert mock_storage.store.call_count == 2 - def test_skips_upload_when_file_not_found(self, tmp_path): - """Missing CLI session file logs debug and skips upload silently.""" + def test_uploads_companion_meta_json_with_message_count(self): + """upload_transcript stores a companion .meta.json with message_count.""" import asyncio + import json from unittest.mock import AsyncMock, patch - from .transcript import upload_cli_session + from .transcript import upload_transcript mock_storage = AsyncMock() - projects_base = str(tmp_path) + content = b'{"type":"assistant"}\n' - with ( - patch( - "backend.copilot.transcript._projects_base", - return_value=projects_base, - ), - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, ): - # session file doesn't exist — should not raise asyncio.run( - upload_cli_session( + upload_transcript( user_id="user-1", - session_id="12345678-0000-0000-0000-000000000000", - sdk_cwd=str(tmp_path), + session_id="12345678-0000-0000-0000-000000000010", + content=content, + message_count=5, ) ) - mock_storage.store.assert_not_called() + assert mock_storage.store.call_count == 2 + # Find the meta.json store call + meta_call = next( + c + for c in mock_storage.store.call_args_list 
+ if c.kwargs.get("filename", "").endswith(".meta.json") + ) + meta_content = json.loads(meta_call.kwargs["content"]) + assert meta_content["message_count"] == 5 - def test_uploads_file_successfully(self, tmp_path): - """Happy path: session file exists within projects base → upload called.""" - import asyncio - from unittest.mock import AsyncMock, patch + def test_skips_upload_on_storage_failure(self): + """Storage exception on jsonl write is logged and does not propagate. - from .transcript import _sanitize_id, upload_cli_session - - projects_base = str(tmp_path) - session_id = "12345678-0000-0000-0000-000000000001" - sdk_cwd = str(tmp_path) - - # Build the path the same way _cli_session_path does, but using our tmp_path - # as projects_base so the boundary check passes. - # Must use the same encoding: re.sub non-alphanumeric → "-" on realpath. - import os - import re - - encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd)) - session_dir = tmp_path / encoded_cwd - session_dir.mkdir(parents=True, exist_ok=True) - session_file = session_dir / f"{_sanitize_id(session_id)}.jsonl" - session_file.write_bytes(b'{"type":"assistant"}\n') - - mock_storage = AsyncMock() - - with ( - patch( - "backend.copilot.transcript._projects_base", - return_value=projects_base, - ), - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), - ): - asyncio.run( - upload_cli_session( - user_id="user-1", - session_id=session_id, - sdk_cwd=sdk_cwd, - ) - ) - - mock_storage.store.assert_called_once() - - def test_skips_upload_on_oserror(self, tmp_path): - """OSError reading session file is logged as warning; upload is skipped.""" - import asyncio - from unittest.mock import AsyncMock, patch - - from .transcript import _sanitize_id, upload_cli_session - - projects_base = str(tmp_path) - sdk_cwd = str(tmp_path) - session_id = "12345678-0000-0000-0000-000000000002" - - # Build file at a path inside projects_base so boundary check passes. - import os - import re - - encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd)) - session_dir = tmp_path / encoded_cwd - session_dir.mkdir(parents=True, exist_ok=True) - session_file = session_dir / f"{_sanitize_id(session_id)}.jsonl" - session_file.write_bytes(b'{"type":"assistant"}\n') - # Remove read permission to trigger OSError - session_file.chmod(0o000) - - mock_storage = AsyncMock() - - try: - with ( - patch( - "backend.copilot.transcript._projects_base", - return_value=projects_base, - ), - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), - ): - asyncio.run( - upload_cli_session( - user_id="user-1", - session_id=session_id, - sdk_cwd=sdk_cwd, - ) - ) - finally: - session_file.chmod(0o644) # restore so tmp_path cleanup works - - mock_storage.store.assert_not_called() - - def test_strips_session_before_upload_and_writes_back(self, tmp_path): - """Strippable entries (progress, thinking blocks) are removed before upload. - - The stripped content is written back to disk (so same-pod turns benefit) - and the smaller bytes are uploaded to GCS. + With sequential writes, JSONL failure returns early — meta store is + never called, so no rollback is needed. 
""" import asyncio - import os - import re from unittest.mock import AsyncMock, patch - from .transcript import _sanitize_id, upload_cli_session + from .transcript import upload_transcript - projects_base = str(tmp_path) - session_id = "12345678-0000-0000-0000-000000000010" - sdk_cwd = str(tmp_path) + mock_storage = AsyncMock() + mock_storage.store.side_effect = RuntimeError("gcs unavailable") + content = b'{"type":"assistant"}\n' - encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd)) - session_dir = tmp_path / encoded_cwd - session_dir.mkdir(parents=True, exist_ok=True) - session_file = session_dir / f"{_sanitize_id(session_id)}.jsonl" + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + # Should not raise — failures are logged as warnings + asyncio.run( + upload_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000002", + content=content, + ) + ) - # A CLI session with a progress entry (strippable) and a real assistant message. + # Only one store call attempted (the JSONL); meta never reached + mock_storage.store.assert_called_once() + mock_storage.delete.assert_not_called() + + def test_rolls_back_session_when_meta_upload_fails(self): + """When meta upload fails after JSONL succeeds, JSONL is rolled back. + + Guarantees the pair is either both present or both absent — avoids an + orphaned JSONL being used with wrong mode/watermark defaults. + """ + import asyncio + from unittest.mock import AsyncMock, patch + + from .transcript import upload_transcript + + mock_storage = AsyncMock() + # First store (JSONL) succeeds; second store (meta) fails + mock_storage.store.side_effect = [None, RuntimeError("meta write failed")] + content = b'{"type":"assistant"}\n' + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + asyncio.run( + upload_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000099", + content=content, + ) + ) + + # Both store calls were attempted (JSONL then meta) + assert mock_storage.store.call_count == 2 + # JSONL should be rolled back via delete + mock_storage.delete.assert_called_once() + + def test_baseline_mode_stored_in_meta(self): + """upload_transcript with mode='baseline' stores mode in companion meta.json.""" + import asyncio import json + from unittest.mock import AsyncMock, patch + + from .transcript import upload_transcript + + mock_storage = AsyncMock() + content = b'{"type":"assistant"}\n' + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + asyncio.run( + upload_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000098", + content=content, + message_count=4, + mode="baseline", + ) + ) + + meta_call = next( + c + for c in mock_storage.store.call_args_list + if c.kwargs.get("filename", "").endswith(".meta.json") + ) + meta_content = json.loads(meta_call.kwargs["content"]) + assert meta_content["mode"] == "baseline" + assert meta_content["message_count"] == 4 + + def test_strips_session_before_upload_and_writes_back(self): + """strip_for_upload removes progress entries and returns smaller content.""" + import json + + from .transcript import strip_for_upload progress_entry = { "type": "progress", @@ -968,64 +911,22 @@ class TestUploadCliSession: + json.dumps(asst_entry) + "\n" ) - raw_bytes = raw_content.encode("utf-8") - 
session_file.write_bytes(raw_bytes) - mock_storage = AsyncMock() + stripped = strip_for_upload(raw_content) - with ( - patch( - "backend.copilot.transcript._projects_base", - return_value=projects_base, - ), - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), - ): - asyncio.run( - upload_cli_session( - user_id="user-1", - session_id=session_id, - sdk_cwd=sdk_cwd, - ) - ) - - # Upload should have been called with stripped bytes (no progress entry). - mock_storage.store.assert_called_once() - stored_content: bytes = mock_storage.store.call_args.kwargs["content"] - stored_lines = stored_content.decode("utf-8").strip().split("\n") + stored_lines = stripped.strip().split("\n") stored_types = [json.loads(line).get("type") for line in stored_lines] assert "progress" not in stored_types assert "user" in stored_types assert "assistant" in stored_types - # Stripped bytes should be smaller than raw. - assert len(stored_content) < len(raw_bytes) - # File on disk should also be the stripped version. - disk_content = session_file.read_bytes() - assert disk_content == stored_content + assert len(stripped.encode()) < len(raw_content.encode()) - def test_strips_stale_thinking_blocks_before_upload(self, tmp_path): - """Thinking blocks in non-last assistant turns are stripped to reduce size.""" - import asyncio + def test_strips_stale_thinking_blocks_before_upload(self): + """strip_for_upload removes thinking blocks from non-last assistant turns.""" import json - import os - import re - from unittest.mock import AsyncMock, patch - from .transcript import _sanitize_id, upload_cli_session + from .transcript import strip_for_upload - projects_base = str(tmp_path) - session_id = "12345678-0000-0000-0000-000000000011" - sdk_cwd = str(tmp_path) - - encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(sdk_cwd)) - session_dir = tmp_path / encoded_cwd - session_dir.mkdir(parents=True, exist_ok=True) - session_file = session_dir / f"{_sanitize_id(session_id)}.jsonl" - - # Two turns: first assistant has thinking block (stale), second doesn't. u1 = { "type": "user", "uuid": "u1", @@ -1070,32 +971,10 @@ class TestUploadCliSession: + json.dumps(a2_no_thinking) + "\n" ) - raw_bytes = raw_content.encode("utf-8") - session_file.write_bytes(raw_bytes) - mock_storage = AsyncMock() + stripped = strip_for_upload(raw_content) - with ( - patch( - "backend.copilot.transcript._projects_base", - return_value=projects_base, - ), - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), - ): - asyncio.run( - upload_cli_session( - user_id="user-1", - session_id=session_id, - sdk_cwd=sdk_cwd, - ) - ) - - stored_content: bytes = mock_storage.store.call_args.kwargs["content"] - stored_lines = stored_content.decode("utf-8").strip().split("\n") + stored_lines = stripped.strip().split("\n") # a1 should have its thinking block stripped (it's not the last assistant turn). a1_stored = json.loads(stored_lines[1]) @@ -1111,20 +990,20 @@ class TestUploadCliSession: a2_stored = json.loads(stored_lines[3]) assert a2_stored["message"]["content"] == [{"type": "text", "text": "answer2"}] - # Stripped bytes smaller than raw. 
- assert len(stored_content) < len(raw_bytes) - class TestRestoreCliSession: - def test_returns_false_when_file_not_found_in_storage(self): - """Returns False (graceful degradation) when the session is missing.""" + def test_returns_none_when_file_not_found_in_storage(self): + """Returns None (graceful degradation) when the session is missing.""" import asyncio from unittest.mock import AsyncMock, patch - from .transcript import restore_cli_session + from .transcript import download_transcript mock_storage = AsyncMock() - mock_storage.retrieve.side_effect = FileNotFoundError("not found") + mock_storage.retrieve.side_effect = [ + FileNotFoundError("no session"), + FileNotFoundError("no meta"), + ] with patch( "backend.copilot.transcript.get_workspace_storage", @@ -1132,144 +1011,26 @@ class TestRestoreCliSession: return_value=mock_storage, ): result = asyncio.run( - restore_cli_session( + download_transcript( user_id="user-1", session_id="12345678-0000-0000-0000-000000000000", - sdk_cwd="/tmp/copilot-test", ) ) - assert result is False + assert result is None - def test_returns_false_when_restore_path_outside_projects_base(self, tmp_path): - """Path traversal guard: rejects restoration outside the projects base.""" + def test_returns_transcript_download_on_success_no_meta(self): + """Happy path with no meta.json: returns TranscriptDownload with message_count=0.""" import asyncio from unittest.mock import AsyncMock, patch - from .transcript import restore_cli_session + from .transcript import download_transcript - mock_storage = AsyncMock() - mock_storage.retrieve.return_value = b'{"type":"assistant"}\n' - - with ( - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), - patch( - "backend.copilot.transcript._projects_base", - return_value=str(tmp_path), - ), - # Return a path genuinely outside tmp_path so the boundary guard fires. 
- patch( - "backend.copilot.transcript._cli_session_path", - return_value="/outside/escaped/session.jsonl", - ), - ): - result = asyncio.run( - restore_cli_session( - user_id="user-1", - session_id="12345678-0000-0000-0000-000000000000", - sdk_cwd=str(tmp_path), - ) - ) - - assert result is False - - def test_returns_true_when_local_file_already_exists(self, tmp_path): - """Same-pod reuse: if local file exists, skip storage download and return True.""" - import asyncio - import os - import re - from pathlib import Path - from unittest.mock import AsyncMock, patch - - from .transcript import restore_cli_session - - session_id = "12345678-0000-0000-0000-000000000099" - sdk_cwd = str(tmp_path) - - # Pre-create the local session file (simulates previous turn on same pod) - projects_base = os.path.realpath(str(tmp_path)) - encoded_cwd = re.sub(r"[^a-zA-Z0-9]", "-", projects_base) - session_dir = Path(projects_base) / encoded_cwd - session_dir.mkdir(parents=True, exist_ok=True) - existing_content = b'{"type":"user"}\n{"type":"assistant"}\n' - (session_dir / f"{session_id}.jsonl").write_bytes(existing_content) - - mock_storage = AsyncMock() - - with ( - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), - patch( - "backend.copilot.transcript._projects_base", - return_value=projects_base, - ), - ): - result = asyncio.run( - restore_cli_session( - user_id="user-1", - session_id=session_id, - sdk_cwd=sdk_cwd, - ) - ) - - assert result is True - # Storage should NOT have been accessed (local file was used as-is) - mock_storage.retrieve.assert_not_called() - # Local file should be unchanged - assert (session_dir / f"{session_id}.jsonl").read_bytes() == existing_content - - def test_returns_true_on_success(self, tmp_path): - """Happy path: storage has the session → file written → returns True.""" - import asyncio - from unittest.mock import AsyncMock, patch - - from .transcript import restore_cli_session - - projects_base = str(tmp_path) - sdk_cwd = str(tmp_path) session_id = "12345678-0000-0000-0000-000000000003" content = b'{"type":"assistant"}\n' mock_storage = AsyncMock() - mock_storage.retrieve.return_value = content - - with ( - patch( - "backend.copilot.transcript.get_workspace_storage", - new_callable=AsyncMock, - return_value=mock_storage, - ), - patch( - "backend.copilot.transcript._projects_base", - return_value=projects_base, - ), - ): - result = asyncio.run( - restore_cli_session( - user_id="user-1", - session_id=session_id, - sdk_cwd=sdk_cwd, - ) - ) - - assert result is True - - def test_returns_false_on_download_exception(self): - """Non-FileNotFoundError during retrieve logs warning and returns False.""" - import asyncio - from unittest.mock import AsyncMock, patch - - from .transcript import restore_cli_session - - mock_storage = AsyncMock() - mock_storage.retrieve.side_effect = RuntimeError("network error") + mock_storage.retrieve.side_effect = [content, FileNotFoundError("no meta")] with patch( "backend.copilot.transcript.get_workspace_storage", @@ -1277,11 +1038,411 @@ class TestRestoreCliSession: return_value=mock_storage, ): result = asyncio.run( - restore_cli_session( + download_transcript( user_id="user-1", - session_id="12345678-0000-0000-0000-000000000004", - sdk_cwd="/tmp/copilot-test", + session_id=session_id, ) ) - assert result is False + assert isinstance(result, TranscriptDownload) + assert result.content == content + assert result.message_count == 0 + assert result.mode == "sdk" + + def 
test_returns_transcript_download_with_message_count_from_meta(self): + """When meta.json is present, message_count and mode are read from it.""" + import asyncio + import json + from unittest.mock import AsyncMock, patch + + from .transcript import download_transcript + + session_id = "12345678-0000-0000-0000-000000000005" + content = b'{"type":"assistant"}\n' + meta_bytes = json.dumps( + {"message_count": 7, "mode": "sdk", "uploaded_at": 1234567.0} + ).encode() + + mock_storage = AsyncMock() + mock_storage.retrieve.side_effect = [content, meta_bytes] + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + result = asyncio.run( + download_transcript( + user_id="user-1", + session_id=session_id, + ) + ) + + assert isinstance(result, TranscriptDownload) + assert result.content == content + assert result.message_count == 7 + assert result.mode == "sdk" + + def test_returns_none_on_download_exception(self): + """Non-FileNotFoundError during retrieve logs warning and returns None.""" + import asyncio + from unittest.mock import AsyncMock, patch + + from .transcript import download_transcript + + mock_storage = AsyncMock() + mock_storage.retrieve.side_effect = [ + RuntimeError("network error"), + FileNotFoundError("no meta"), + ] + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + result = asyncio.run( + download_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000004", + ) + ) + + assert result is None + + def test_baseline_mode_in_meta_returned(self): + """When meta.json contains mode='baseline', result.mode is 'baseline'.""" + import asyncio + import json + from unittest.mock import AsyncMock, patch + + from .transcript import download_transcript + + content = b'{"type":"assistant"}\n' + meta_bytes = json.dumps( + {"message_count": 3, "mode": "baseline", "uploaded_at": 0.0} + ).encode() + + mock_storage = AsyncMock() + mock_storage.retrieve.side_effect = [content, meta_bytes] + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + result = asyncio.run( + download_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000020", + ) + ) + + assert isinstance(result, TranscriptDownload) + assert result.mode == "baseline" + assert result.message_count == 3 + + def test_invalid_mode_in_meta_defaults_to_sdk(self): + """Unknown mode value in meta.json falls back to 'sdk'.""" + import asyncio + import json + from unittest.mock import AsyncMock, patch + + from .transcript import download_transcript + + content = b'{"type":"assistant"}\n' + meta_bytes = json.dumps({"message_count": 2, "mode": "unknown_mode"}).encode() + + mock_storage = AsyncMock() + mock_storage.retrieve.side_effect = [content, meta_bytes] + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + result = asyncio.run( + download_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000021", + ) + ) + + assert isinstance(result, TranscriptDownload) + assert result.mode == "sdk" + + def test_invalid_utf8_meta_uses_defaults(self): + """Meta bytes that fail UTF-8 decode fall back to message_count=0, mode='sdk'.""" + import asyncio + from unittest.mock import AsyncMock, patch + + from .transcript import download_transcript + + content = b'{"type":"assistant"}\n' + 
bad_meta = b"\xff\xfe" + + mock_storage = AsyncMock() + mock_storage.retrieve.side_effect = [content, bad_meta] + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + result = asyncio.run( + download_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000022", + ) + ) + + assert isinstance(result, TranscriptDownload) + assert result.message_count == 0 + assert result.mode == "sdk" + + def test_meta_fetch_exception_uses_defaults(self): + """Non-FileNotFoundError on meta fetch still returns content with defaults.""" + import asyncio + from unittest.mock import AsyncMock, patch + + from .transcript import download_transcript + + content = b'{"type":"assistant"}\n' + + mock_storage = AsyncMock() + mock_storage.retrieve.side_effect = [content, RuntimeError("meta unavailable")] + + with patch( + "backend.copilot.transcript.get_workspace_storage", + new_callable=AsyncMock, + return_value=mock_storage, + ): + result = asyncio.run( + download_transcript( + user_id="user-1", + session_id="12345678-0000-0000-0000-000000000023", + ) + ) + + assert isinstance(result, TranscriptDownload) + assert result.content == content + assert result.message_count == 0 + assert result.mode == "sdk" + + +# --------------------------------------------------------------------------- +# detect_gap +# --------------------------------------------------------------------------- + + +def _msgs(*roles: str): + """Build a list of ChatMessage objects with the given roles.""" + from .model import ChatMessage + + return [ChatMessage(role=r, content=f"{r}-{i}") for i, r in enumerate(roles)] + + +class TestDetectGap: + """``detect_gap`` returns messages between transcript watermark and current turn.""" + + def _dl(self, message_count: int) -> TranscriptDownload: + return TranscriptDownload(content=b"", message_count=message_count, mode="sdk") + + def test_zero_watermark_returns_empty(self): + """message_count=0 means no watermark — skip gap detection.""" + dl = self._dl(0) + messages = _msgs("user", "assistant", "user") + assert detect_gap(dl, messages) == [] + + def test_watermark_covers_all_prefix_returns_empty(self): + """Transcript already covers all messages up to the current user turn.""" + # session: [user, assistant, user(current)] — wm=2 means covers up to assistant + dl = self._dl(2) + messages = _msgs("user", "assistant", "user") + assert detect_gap(dl, messages) == [] + + def test_watermark_exceeds_session_returns_empty(self): + """Watermark ahead of session count (race / over-count) → no gap.""" + dl = self._dl(10) + messages = _msgs("user", "assistant", "user") + assert detect_gap(dl, messages) == [] + + def test_misaligned_watermark_not_on_assistant_returns_empty(self): + """Watermark at a user-role position is misaligned — skip gap.""" + # wm=1: position 0 is 'user', not 'assistant' → skip + dl = self._dl(1) + messages = _msgs("user", "assistant", "user", "assistant", "user") + assert detect_gap(dl, messages) == [] + + def test_returns_gap_messages(self): + """Watermark behind session — gap messages returned (excluding current turn).""" + # session: [user0, assistant1, user2, assistant3, user4(current)] + # wm=2: transcript covers [0,1]; gap = [user2, assistant3] + dl = self._dl(2) + messages = _msgs("user", "assistant", "user", "assistant", "user") + gap = detect_gap(dl, messages) + assert len(gap) == 2 + assert gap[0].role == "user" + assert gap[1].role == "assistant" + + def test_excludes_current_user_turn(self): 
+ """The last message (current user turn) is never included in the gap.""" + # wm=2, session has 4 msgs: gap = [msg2] only (msg3 is current turn → excluded) + dl = self._dl(2) + messages = _msgs("user", "assistant", "user", "user") + gap = detect_gap(dl, messages) + assert len(gap) == 1 + assert gap[0].role == "user" + + def test_single_gap_message(self): + """One message between watermark and current turn.""" + # session: [user0, assistant1, user2, assistant3, user4(current)] + # wm=3: position 2 is 'user' → misaligned, returns [] + # use wm=4: but 4 >= total-1=4 → also empty + # wm=3 with session [u, a, u, a, u, a, u(current)]: position 2 is 'user' → empty + # Valid case: wm=2 has 3 messages (assistant at 1), wm=4 with [u,a,u,a,u,a,u]: + # let's use wm=4 with 7 messages: wm=4 >= total-1=6? no, 4<6. pos[3]=assistant → gap=[msg4,msg5] + # simpler: wm=2, [u0,a1,a2,u3(current)] — pos[1]=assistant, gap=[a2] only + dl = self._dl(2) + messages = _msgs("user", "assistant", "assistant", "user") + gap = detect_gap(dl, messages) + assert len(gap) == 1 + assert gap[0].role == "assistant" + + +# --------------------------------------------------------------------------- +# extract_context_messages +# --------------------------------------------------------------------------- + + +def _make_valid_transcript(*roles: str) -> str: + """Build a minimal valid JSONL transcript with the given message roles.""" + import json as stdlib_json + + from .transcript import STOP_REASON_END_TURN + + lines = [] + parent = "" + for i, role in enumerate(roles): + uid = f"uid-{i}" + entry: dict = { + "type": role, + "uuid": uid, + "parentUuid": parent, + "message": { + "role": role, + "content": f"{role} content {i}", + }, + } + if role == "assistant": + entry["message"]["id"] = f"msg_{i}" + entry["message"]["model"] = "test-model" + entry["message"]["type"] = "message" + entry["message"]["stop_reason"] = STOP_REASON_END_TURN + entry["message"]["content"] = [ + {"type": "text", "text": f"assistant content {i}"} + ] + lines.append(stdlib_json.dumps(entry)) + parent = uid + return "\n".join(lines) + "\n" + + +class TestExtractContextMessages: + """``extract_context_messages`` returns the shared context primitive.""" + + def test_none_download_returns_prior(self): + """No download → falls back to all session messages except current turn.""" + messages = _msgs("user", "assistant", "user") + result = extract_context_messages(None, messages) + assert result == messages[:-1] + assert len(result) == 2 + + def test_empty_content_download_returns_prior(self): + """Empty bytes content → falls back to all prior messages.""" + dl = TranscriptDownload(content=b"", message_count=2, mode="sdk") + messages = _msgs("user", "assistant", "user") + result = extract_context_messages(dl, messages) + assert result == messages[:-1] + + def test_valid_transcript_no_gap_returns_transcript_messages(self): + """Transcript covers all prior turns → only transcript messages returned.""" + # Transcript: [user, assistant] — 2 messages + # Session: [user, assistant, user(current)] — watermark=2 covers prefix + transcript_content = _make_valid_transcript("user", "assistant") + dl = TranscriptDownload( + content=transcript_content.encode("utf-8"), message_count=2, mode="sdk" + ) + messages = _msgs("user", "assistant", "user") + result = extract_context_messages(dl, messages) + # Transcript has 2 messages (user + assistant) and no gap + assert len(result) == 2 + assert result[0].role == "user" + assert result[1].role == "assistant" + + def 
test_valid_transcript_with_gap_returns_transcript_plus_gap(self): + """Transcript is stale → gap messages appended after transcript content.""" + # Transcript: [user, assistant] — watermark=2 + # Session: [user, assistant, user, assistant, user(current)] + # Gap: [user(2), assistant(3)] — positions 2 and 3 + transcript_content = _make_valid_transcript("user", "assistant") + dl = TranscriptDownload( + content=transcript_content.encode("utf-8"), message_count=2, mode="sdk" + ) + messages = _msgs("user", "assistant", "user", "assistant", "user") + result = extract_context_messages(dl, messages) + # 2 transcript messages + 2 gap messages = 4 + assert len(result) == 4 + assert result[0].role == "user" # transcript user + assert result[1].role == "assistant" # transcript assistant + assert result[2].role == "user" # gap user + assert result[3].role == "assistant" # gap assistant + + def test_compact_summary_entries_preserved(self): + """``isCompactSummary=True`` entries survive ``_transcript_to_messages``.""" + import json as stdlib_json + + from .transcript import STOP_REASON_END_TURN + + # Build a transcript where one entry is a compaction summary. + # isCompactSummary=True entries have type in STRIPPABLE_TYPES but are kept. + compact_entry = stdlib_json.dumps( + { + "type": "summary", + "uuid": "uid-compact", + "parentUuid": "", + "isCompactSummary": True, + "message": { + "role": "user", + "content": "COMPACT_SUMMARY_CONTENT", + }, + } + ) + assistant_entry = stdlib_json.dumps( + { + "type": "assistant", + "uuid": "uid-1", + "parentUuid": "uid-compact", + "message": { + "role": "assistant", + "id": "msg_1", + "model": "test", + "type": "message", + "stop_reason": STOP_REASON_END_TURN, + "content": [{"type": "text", "text": "response after compact"}], + }, + } + ) + content = compact_entry + "\n" + assistant_entry + "\n" + dl = TranscriptDownload( + content=content.encode("utf-8"), message_count=2, mode="sdk" + ) + messages = _msgs("user", "assistant", "user") + result = extract_context_messages(dl, messages) + # Both the compact summary and the assistant response are present + assert len(result) == 2 + roles = [m.role for m in result] + assert "user" in roles # compact summary has role=user + assert "assistant" in roles + # The compact summary content is preserved + compact_msgs = [m for m in result if m.role == "user"] + assert any("COMPACT_SUMMARY_CONTENT" in (m.content or "") for m in compact_msgs) diff --git a/autogpt_platform/backend/backend/data/block_cost_config.py b/autogpt_platform/backend/backend/data/block_cost_config.py index 1753d5e65e..a4a9a8ef55 100644 --- a/autogpt_platform/backend/backend/data/block_cost_config.py +++ b/autogpt_platform/backend/backend/data/block_cost_config.py @@ -143,6 +143,8 @@ MODEL_COST: dict[LlmModel, int] = { LlmModel.GROK_4: 9, LlmModel.GROK_4_FAST: 1, LlmModel.GROK_4_1_FAST: 1, + LlmModel.GROK_4_20: 5, + LlmModel.GROK_4_20_MULTI_AGENT: 5, LlmModel.GROK_CODE_FAST_1: 1, LlmModel.KIMI_K2: 1, LlmModel.QWEN3_235B_A22B_THINKING: 1, diff --git a/autogpt_platform/backend/backend/data/credit.py b/autogpt_platform/backend/backend/data/credit.py index 24b5aae80d..e97578d5cc 100644 --- a/autogpt_platform/backend/backend/data/credit.py +++ b/autogpt_platform/backend/backend/data/credit.py @@ -1,10 +1,13 @@ +import asyncio import logging +import time from abc import ABC, abstractmethod from collections import defaultdict from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, cast import stripe +from fastapi.concurrency import 
run_in_threadpool from prisma.enums import ( CreditRefundRequestStatus, CreditTransactionType, @@ -31,6 +34,7 @@ from backend.data.model import ( from backend.data.notifications import NotificationEventModel, RefundRequestData from backend.data.user import get_user_by_id, get_user_email_by_id from backend.notifications.notifications import queue_notification_async +from backend.util.cache import cached from backend.util.exceptions import InsufficientBalanceError from backend.util.feature_flag import Flag, get_feature_flag_value, is_feature_enabled from backend.util.json import SafeJson, dumps @@ -432,7 +436,7 @@ class UserCreditBase(ABC): current_balance, _ = await self._get_credits(user_id) if current_balance >= ceiling_balance: raise ValueError( - f"You already have enough balance of ${current_balance/100}, top-up is not required when you already have at least ${ceiling_balance/100}" + f"You already have enough balance of ${current_balance / 100}, top-up is not required when you already have at least ${ceiling_balance / 100}" ) # Single unified atomic operation for all transaction types using UserBalance @@ -571,7 +575,7 @@ class UserCreditBase(ABC): if amount < 0 and fail_insufficient_credits: current_balance, _ = await self._get_credits(user_id) raise InsufficientBalanceError( - message=f"Insufficient balance of ${current_balance/100}, where this will cost ${abs(amount)/100}", + message=f"Insufficient balance of ${current_balance / 100}, where this will cost ${abs(amount) / 100}", user_id=user_id, balance=current_balance, amount=amount, @@ -582,7 +586,6 @@ class UserCreditBase(ABC): class UserCredit(UserCreditBase): - async def _send_refund_notification( self, notification_request: RefundRequestData, @@ -734,7 +737,7 @@ class UserCredit(UserCreditBase): ) if request.amount <= 0 or request.amount > transaction.amount: raise AssertionError( - f"Invalid amount to deduct ${request.amount/100} from ${transaction.amount/100} top-up" + f"Invalid amount to deduct ${request.amount / 100} from ${transaction.amount / 100} top-up" ) balance, _ = await self._add_transaction( @@ -788,12 +791,12 @@ class UserCredit(UserCreditBase): # If the user has enough balance, just let them win the dispute. if balance - amount >= settings.config.refund_credit_tolerance_threshold: - logger.warning(f"Accepting dispute from {user_id} for ${amount/100}") + logger.warning(f"Accepting dispute from {user_id} for ${amount / 100}") dispute.close() return logger.warning( - f"Adding extra info for dispute from {user_id} for ${amount/100}" + f"Adding extra info for dispute from {user_id} for ${amount / 100}" ) # Retrieve recent transaction history to support our evidence. # This provides a concise timeline that shows service usage and proper credit application. @@ -1237,14 +1240,23 @@ async def get_stripe_customer_id(user_id: str) -> str: if user.stripe_customer_id: return user.stripe_customer_id - customer = stripe.Customer.create( + # Race protection: two concurrent calls (e.g. user double-clicks "Upgrade", + # or any retried request) would each pass the check above and create their + # own Stripe Customer, leaving an orphaned billable customer in Stripe. + # Pass an idempotency_key so Stripe collapses concurrent + retried calls + # into the same Customer object server-side. The 24h Stripe idempotency + # window comfortably covers any realistic in-flight retry scenario. 
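+    # Sketch of the race this key closes (hypothetical timeline): requests A and
+    # B both pass the stripe_customer_id check above and both call
+    # stripe.Customer.create(..., idempotency_key=f"customer-create-{user_id}");
+    # Stripe creates one Customer for the first call and replays the identical
+    # response to the second, so only one billable customer ever exists.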
+ customer = await run_in_threadpool( + stripe.Customer.create, name=user.name or "", email=user.email, metadata={"user_id": user_id}, + idempotency_key=f"customer-create-{user_id}", ) await User.prisma().update( where={"id": user_id}, data={"stripeCustomerId": customer.id} ) + get_user_by_id.cache_delete(user_id) return customer.id @@ -1263,23 +1275,203 @@ async def set_subscription_tier(user_id: str, tier: SubscriptionTier) -> None: data={"subscriptionTier": tier}, ) get_user_by_id.cache_delete(user_id) + # Also invalidate the rate-limit tier cache so CoPilot picks up the new + # tier immediately rather than waiting up to 5 minutes for the TTL to expire. + from backend.copilot.rate_limit import get_user_tier # local import avoids circular + + get_user_tier.cache_delete(user_id) # type: ignore[attr-defined] -async def cancel_stripe_subscription(user_id: str) -> None: - """Cancel all active Stripe subscriptions for a user (called on downgrade to FREE).""" - customer_id = await get_stripe_customer_id(user_id) - subscriptions = stripe.Subscription.list( - customer=customer_id, status="active", limit=10 - ) - for sub in subscriptions.auto_paging_iter(): - try: - stripe.Subscription.cancel(sub["id"]) - except stripe.StripeError: - logger.warning( - "cancel_stripe_subscription: failed to cancel sub %s for user %s", - sub["id"], - user_id, +async def _cancel_customer_subscriptions( + customer_id: str, + exclude_sub_id: str | None = None, + at_period_end: bool = False, +) -> int: + """Cancel all billable Stripe subscriptions for a customer, optionally excluding one. + + Cancels both ``active`` and ``trialing`` subscriptions, since trialing subs will + start billing once the trial ends and must be cleaned up on downgrade/upgrade to + avoid double-charging or charging users who intended to cancel. + + When ``at_period_end=True``, schedules cancellation at the end of the current + billing period instead of cancelling immediately — the user keeps their tier + until the period ends, then ``customer.subscription.deleted`` fires and the + webhook downgrades them to FREE. + + Wraps every synchronous Stripe SDK call with run_in_threadpool so the async event + loop is never blocked. Raises stripe.StripeError on list/cancel failure so callers + that need strict consistency can react; cleanup callers can catch and log instead. + + Returns the number of subscriptions cancelled/scheduled for cancellation. + """ + # Query active and trialing separately; Stripe's list API accepts a single status + # filter at a time (no OR), and we explicitly want to skip canceled/incomplete/ + # past_due subs rather than filter them out client-side via status="all". + seen_ids: set[str] = set() + for status in ("active", "trialing"): + subscriptions = await run_in_threadpool( + stripe.Subscription.list, customer=customer_id, status=status, limit=10 + ) + # Iterate only the first page (up to 10); avoid auto_paging_iter which would + # trigger additional sync HTTP calls inside the event loop. 
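+        # For illustration: a customer with, say, 12 active subscriptions gets
+        # the first 10 cancelled below; has_more flags the overflow, which is
+        # logged as an error so the remainder can be cleaned up manually.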
+ if subscriptions.has_more: + logger.error( + "_cancel_customer_subscriptions: customer %s has more than 10 %s" + " subscriptions — only the first page was processed; remaining" + " subscriptions were NOT cancelled", + customer_id, + status, ) + for sub in subscriptions.data: + sub_id = sub["id"] + if exclude_sub_id and sub_id == exclude_sub_id: + continue + if sub_id in seen_ids: + continue + seen_ids.add(sub_id) + if at_period_end: + await run_in_threadpool( + stripe.Subscription.modify, sub_id, cancel_at_period_end=True + ) + else: + await run_in_threadpool(stripe.Subscription.cancel, sub_id) + return len(seen_ids) + + +async def cancel_stripe_subscription(user_id: str) -> bool: + """Schedule cancellation of all active/trialing Stripe subscriptions at period end. + + The subscription stays active until the end of the billing period so the user + keeps their tier for the time they already paid for. The ``customer.subscription.deleted`` + webhook fires at period end and downgrades the DB tier to FREE. + + Returns True if at least one subscription was found and scheduled for cancellation, + False if the customer had no active/trialing subscriptions (e.g., admin-granted tier + with no associated Stripe subscription). When False, the caller should update the + DB tier directly since no webhook will fire to do it. + + Raises stripe.StripeError if any modification fails, so the caller can avoid + updating the DB tier when Stripe is inconsistent. + """ + # Guard: only proceed if the user already has a Stripe customer ID. Calling + # get_stripe_customer_id for a user who has never had a paid subscription would + # create an orphaned, potentially-billable Stripe Customer object — we avoid that + # by returning False early so the caller can downgrade the DB tier directly. + user = await get_user_by_id(user_id) + if not user.stripe_customer_id: + return False + + customer_id = user.stripe_customer_id + try: + cancelled_count = await _cancel_customer_subscriptions( + customer_id, at_period_end=True + ) + return cancelled_count > 0 + except stripe.StripeError: + logger.warning( + "cancel_stripe_subscription: Stripe error while cancelling subs for user %s", + user_id, + ) + raise + + +async def get_proration_credit_cents(user_id: str, monthly_cost_cents: int) -> int: + """Return the prorated credit (in cents) the user would receive if they upgraded now. + + Fetches the user's active Stripe subscription to determine how many seconds + remain in the current billing period, then calculates the unused portion of + the monthly cost. Returns 0 for FREE/ENTERPRISE users or when no active sub + is found. + """ + if monthly_cost_cents <= 0: + return 0 + # Guard: only query Stripe if the user already has a customer ID. Admin-granted + # paid tiers have no Stripe record; calling get_stripe_customer_id would create an + # orphaned customer on every billing-page load for those users. 
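+    # Worked example of the proration computed below: monthly_cost_cents=2000
+    # with 15 of 30 days left in the billing period gives
+    #     int(2000 * (15 * 86400) / (30 * 86400)) == 1000
+    # i.e. a $10.00 credit for the unused half of a $20.00 plan.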
+ user = await get_user_by_id(user_id) + if not user.stripe_customer_id: + return 0 + try: + customer_id = user.stripe_customer_id + subscriptions = await run_in_threadpool( + stripe.Subscription.list, customer=customer_id, status="active", limit=1 + ) + if not subscriptions.data: + return 0 + sub = subscriptions.data[0] + period_start: int = sub["current_period_start"] + period_end: int = sub["current_period_end"] + now = int(time.time()) + total_seconds = period_end - period_start + remaining_seconds = max(period_end - now, 0) + if total_seconds <= 0: + return 0 + return int(monthly_cost_cents * remaining_seconds / total_seconds) + except Exception: + logger.warning( + "get_proration_credit_cents: failed to compute proration for user %s", + user_id, + ) + return 0 + + +async def modify_stripe_subscription_for_tier( + user_id: str, tier: SubscriptionTier +) -> bool: + """Modify an existing Stripe subscription to a new paid tier using proration. + + For paid→paid tier changes (e.g. PRO↔BUSINESS), modifying the existing + subscription is preferable to cancelling + creating a new one via Checkout: + Stripe handles proration automatically, crediting unused time on the old plan + and charging the pro-rated amount for the new plan in the same billing cycle. + + Returns: + True — a subscription was found and modified successfully. + False — no active/trialing subscription exists (e.g. admin-granted tier or + first-time paid signup); caller should fall back to Checkout. + + Raises stripe.StripeError on API failures so callers can propagate a 502. + Raises ValueError when no Stripe price ID is configured for the tier. + """ + price_id = await get_subscription_price_id(tier) + if not price_id: + raise ValueError(f"No Stripe price ID configured for tier {tier}") + + # Guard: only proceed if the user already has a Stripe customer ID. Calling + # get_stripe_customer_id for a user with no Stripe record (e.g. admin-granted tier) + # would create an orphaned customer object if the subsequent Subscription.list call + # fails. Return False early so the API layer falls back to Checkout instead. + user = await get_user_by_id(user_id) + if not user.stripe_customer_id: + return False + + customer_id = user.stripe_customer_id + for status in ("active", "trialing"): + subscriptions = await run_in_threadpool( + stripe.Subscription.list, customer=customer_id, status=status, limit=1 + ) + if not subscriptions.data: + continue + sub = subscriptions.data[0] + sub_id = sub["id"] + items = sub.get("items", {}).get("data", []) + if not items: + continue + item_id = items[0]["id"] + await run_in_threadpool( + stripe.Subscription.modify, + sub_id, + items=[{"id": item_id, "price": price_id}], + proration_behavior="create_prorations", + ) + logger.info( + "modify_stripe_subscription_for_tier: modified sub %s for user %s → %s", + sub_id, + user_id, + tier, + ) + return True + return False async def get_auto_top_up(user_id: str) -> AutoTopUpConfig: @@ -1291,8 +1483,19 @@ async def get_auto_top_up(user_id: str) -> AutoTopUpConfig: return AutoTopUpConfig.model_validate(user.top_up_config) +@cached(ttl_seconds=60, maxsize=8, cache_none=False) async def get_subscription_price_id(tier: SubscriptionTier) -> str | None: - """Return Stripe Price ID for a tier from LaunchDarkly. None = not configured.""" + """Return Stripe Price ID for a tier from LaunchDarkly, cached for 60 seconds. + + Price IDs are LaunchDarkly flag values that change only at deploy time. 
+    Caching for 60 seconds avoids hitting the LD SDK on every webhook delivery
+    and every GET /credits/subscription page load (called 2x per request).
+
+    ``cache_none=False`` prevents a transient LD failure from caching ``None``
+    and blocking subscription upgrades for the full 60-second TTL window.
+    A tier with no configured flag (FREE, ENTERPRISE) returns ``None`` from an
+    O(1) dict lookup before hitting LD, so the extra LD call is never made.
+    """
     flag_map = {
         SubscriptionTier.PRO: Flag.STRIPE_PRICE_PRO,
         SubscriptionTier.BUSINESS: Flag.STRIPE_PRICE_BUSINESS,
@@ -1300,7 +1503,7 @@ async def get_subscription_price_id(tier: SubscriptionTier) -> str | None:
     flag = flag_map.get(tier)
     if flag is None:
         return None
-    price_id = await get_feature_flag_value(flag.value, user_id="", default="")
+    price_id = await get_feature_flag_value(flag.value, user_id="system", default="")
     return price_id if isinstance(price_id, str) and price_id else None
@@ -1315,7 +1518,8 @@ async def create_subscription_checkout(
     if not price_id:
         raise ValueError(f"Subscription not available for tier {tier.value}")
     customer_id = await get_stripe_customer_id(user_id)
-    session = stripe.checkout.Session.create(
+    session = await run_in_threadpool(
+        stripe.checkout.Session.create,
         customer=customer_id,
         mode="subscription",
         line_items=[{"price": price_id, "quantity": 1}],
@@ -1323,26 +1527,111 @@ async def create_subscription_checkout(
         cancel_url=cancel_url,
         subscription_data={"metadata": {"user_id": user_id, "tier": tier.value}},
     )
-    return session.url or ""
+    if not session.url:
+        # An empty checkout URL for a paid upgrade is always an error; surfacing it
+        # as ValueError means the API handler returns 422 instead of silently
+        # redirecting the client to an empty URL.
+        raise ValueError("Stripe did not return a checkout session URL")
+    return session.url
+
+
+async def _cleanup_stale_subscriptions(customer_id: str, new_sub_id: str) -> None:
+    """Best-effort cancel of any active subs for the customer other than new_sub_id.
+
+    Called from the webhook handler after a new subscription becomes active. Failures
+    are logged but not raised so a transient Stripe error doesn't crash the webhook —
+    a periodic reconciliation job is the intended backstop for persistent drift.
+
+    NOTE: until that reconcile job lands, a failure here means the user is silently
+    billed for two simultaneous subscriptions. The error log below is intentionally
+    `logger.exception` so it surfaces in Sentry with the customer/sub IDs needed to
+    manually reconcile, and the log message is tagged
+    `stripe_stale_subscription_cleanup_failed` so on-call can alert on persistent drift.
+    TODO(#stripe-reconcile-job): replace this best-effort cleanup with a periodic
+    reconciliation job that queries Stripe for customers with >1 active sub.
+    """
+    try:
+        await _cancel_customer_subscriptions(customer_id, exclude_sub_id=new_sub_id)
+    except stripe.StripeError:
+        # Use exception() (not warning) so this surfaces as an error in Sentry —
+        # any failure here means a paid-to-paid upgrade may have left the user
+        # with two simultaneous active subscriptions.
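+        # Hedged sketch of that manual reconciliation, using only the
+        # stripe-python calls already used in this module (run from an
+        # operator shell with the IDs taken from this log line):
+        #     for sub in stripe.Subscription.list(
+        #         customer=customer_id, status="active"
+        #     ).auto_paging_iter():
+        #         if sub["id"] != new_sub_id:
+        #             stripe.Subscription.cancel(sub["id"])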
+ logger.exception( + "stripe_stale_subscription_cleanup_failed: customer=%s new_sub=%s —" + " user may be billed for two simultaneous subscriptions; manual" + " reconciliation required", + customer_id, + new_sub_id, + ) async def sync_subscription_from_stripe(stripe_subscription: dict) -> None: - """Update User.subscriptionTier from a Stripe subscription object.""" - customer_id = stripe_subscription["customer"] + """Update User.subscriptionTier from a Stripe subscription object. + + Expected shape of stripe_subscription (subset of Stripe's Subscription object): + customer: str — Stripe customer ID + status: str — "active" | "trialing" | "canceled" | ... + id: str — Stripe subscription ID + items.data[].price.id: str — Stripe price ID identifying the tier + """ + customer_id = stripe_subscription.get("customer") + if not customer_id: + logger.warning( + "sync_subscription_from_stripe: missing 'customer' field in event, " + "skipping (keys: %s)", + list(stripe_subscription.keys()), + ) + return user = await User.prisma().find_first(where={"stripeCustomerId": customer_id}) if not user: logger.warning( "sync_subscription_from_stripe: no user for customer %s", customer_id ) return + # Cross-check: if the subscription carries a metadata.user_id (set during + # Checkout Session creation), verify it matches the user we found via + # stripeCustomerId. A mismatch indicates a customer↔user mapping + # inconsistency — updating the wrong user's tier would be a data-corruption + # bug, so we log loudly and bail out. Absence of metadata.user_id (e.g. + # subscriptions created outside the Checkout flow) is not an error — we + # simply skip the check and proceed with the customer-ID-based lookup. + metadata = stripe_subscription.get("metadata") or {} + metadata_user_id = metadata.get("user_id") if isinstance(metadata, dict) else None + if metadata_user_id and metadata_user_id != user.id: + logger.error( + "sync_subscription_from_stripe: metadata.user_id=%s does not match" + " user.id=%s found via stripeCustomerId=%s — refusing to update tier" + " to avoid corrupting the wrong user's subscription state", + metadata_user_id, + user.id, + customer_id, + ) + return + # ENTERPRISE tiers are admin-managed. Never let a Stripe webhook flip an + # ENTERPRISE user to a different tier — if a user on ENTERPRISE somehow has + # a self-service Stripe sub, it's a data-consistency issue for an operator, + # not something the webhook should automatically "fix". 
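+    # Summary of the status/price -> tier mapping implemented below (a reading
+    # aid, not new behaviour):
+    #     any event, user is ENTERPRISE    -> keep tier (admin-managed)
+    #     active/trialing + PRO price      -> SubscriptionTier.PRO
+    #     active/trialing + BUSINESS price -> SubscriptionTier.BUSINESS
+    #     active/trialing + unknown price  -> keep current tier, no DB write
+    #     canceled, no other live sub      -> SubscriptionTier.FREE
+    #     canceled, another live sub       -> keep current tier, no DB write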
+ current_tier = user.subscriptionTier or SubscriptionTier.FREE + if current_tier == SubscriptionTier.ENTERPRISE: + logger.warning( + "sync_subscription_from_stripe: refusing to overwrite ENTERPRISE tier" + " for user %s (customer %s); event status=%s", + user.id, + customer_id, + stripe_subscription.get("status", ""), + ) + return status = stripe_subscription.get("status", "") + new_sub_id = stripe_subscription.get("id", "") if status in ("active", "trialing"): price_id = "" items = stripe_subscription.get("items", {}).get("data", []) if items: price_id = items[0].get("price", {}).get("id", "") - pro_price = await get_subscription_price_id(SubscriptionTier.PRO) - biz_price = await get_subscription_price_id(SubscriptionTier.BUSINESS) + pro_price, biz_price = await asyncio.gather( + get_subscription_price_id(SubscriptionTier.PRO), + get_subscription_price_id(SubscriptionTier.BUSINESS), + ) if price_id and pro_price and price_id == pro_price: tier = SubscriptionTier.PRO elif price_id and biz_price and price_id == biz_price: @@ -1359,10 +1648,206 @@ async def sync_subscription_from_stripe(stripe_subscription: dict) -> None: ) return else: + # A subscription was cancelled or ended. DO NOT unconditionally downgrade + # to FREE — Stripe does not guarantee webhook delivery order, so a + # `customer.subscription.deleted` for the OLD sub can arrive after we've + # already processed `customer.subscription.created` for a new paid sub. + # Ask Stripe whether any OTHER active/trialing subs exist for this + # customer; if they do, keep the user's current tier (the other sub's + # own event will/has already set the correct tier). + try: + other_subs_active, other_subs_trialing = await asyncio.gather( + run_in_threadpool( + stripe.Subscription.list, + customer=customer_id, + status="active", + limit=10, + ), + run_in_threadpool( + stripe.Subscription.list, + customer=customer_id, + status="trialing", + limit=10, + ), + ) + except stripe.StripeError: + logger.warning( + "sync_subscription_from_stripe: could not verify other active" + " subs for customer %s on cancel event %s; preserving current" + " tier to avoid an unsafe downgrade", + customer_id, + new_sub_id, + ) + return + # Filter out the cancelled subscription to check if other active subs + # exist. When new_sub_id is empty (malformed event with no 'id' field), + # we cannot safely exclude any sub — preserve current tier to avoid + # an unsafe downgrade on a malformed webhook payload. + if not new_sub_id: + logger.warning( + "sync_subscription_from_stripe: cancel event missing 'id' field" + " for customer %s; preserving current tier", + customer_id, + ) + return + other_active_ids = {sub["id"] for sub in other_subs_active.data} - {new_sub_id} + other_trialing_ids = {sub["id"] for sub in other_subs_trialing.data} - { + new_sub_id + } + still_has_active_sub = bool(other_active_ids or other_trialing_ids) + if still_has_active_sub: + logger.info( + "sync_subscription_from_stripe: sub %s cancelled but customer %s" + " still has another active sub; keeping tier %s", + new_sub_id, + customer_id, + current_tier.value, + ) + return tier = SubscriptionTier.FREE + # Idempotency: Stripe retries webhooks on delivery failure, and several event + # types map to the same final tier. Skip the DB write + cache invalidation + # when the tier is already correct to avoid redundant writes on replay. + if current_tier == tier: + return + # When a new subscription becomes active (e.g. 
paid-to-paid tier upgrade + # via a fresh Checkout Session), cancel any OTHER active subscriptions for + # the same customer so the user isn't billed twice. We do this in the + # webhook rather than the API handler so that abandoning the checkout + # doesn't leave the user without a subscription. + # IMPORTANT: this runs AFTER the idempotency check above so that webhook + # replays for an already-applied event do NOT trigger another cleanup round + # (which could otherwise cancel a legitimately new subscription the user + # signed up for between the original event and its replay). + if status in ("active", "trialing") and new_sub_id: + # NOTE: paid-to-paid upgrade race (e.g. PRO → BUSINESS): + # _cleanup_stale_subscriptions cancels the old PRO sub before + # set_subscription_tier writes BUSINESS to the DB. If Stripe delivers + # the PRO `customer.subscription.deleted` event concurrently and it + # processes after the PRO cancel but before set_subscription_tier + # commits, the user could momentarily appear as FREE in the DB. + # This window is very short in practice (two sequential awaits), + # but is a known limitation of the current webhook-driven approach. + # A future improvement would be to write the new tier first, then + # cancel the old sub. + await _cleanup_stale_subscriptions(customer_id, new_sub_id) await set_subscription_tier(user.id, tier) +async def handle_subscription_payment_failure(invoice: dict) -> None: + """Handle a failed Stripe subscription payment. + + Tries to cover the invoice amount from the user's credit balance. + + - Balance sufficient → deduct from balance, then pay the Stripe invoice so + Stripe stops retrying it. The sub stays intact and the user keeps their tier. + - Balance insufficient → cancel Stripe sub immediately, downgrade to FREE. + Cancelling here avoids further Stripe retries on an invoice we cannot cover. + """ + customer_id = invoice.get("customer") + if not customer_id: + logger.warning( + "handle_subscription_payment_failure: missing customer in invoice; skipping" + ) + return + + user = await User.prisma().find_first(where={"stripeCustomerId": customer_id}) + if not user: + logger.warning( + "handle_subscription_payment_failure: no user found for customer %s", + customer_id, + ) + return + + current_tier = user.subscriptionTier or SubscriptionTier.FREE + if current_tier == SubscriptionTier.ENTERPRISE: + logger.warning( + "handle_subscription_payment_failure: skipping ENTERPRISE user %s" + " (customer %s) — tier is admin-managed", + user.id, + customer_id, + ) + return + + amount_due: int = invoice.get("amount_due", 0) + sub_id: str = invoice.get("subscription", "") + invoice_id: str = invoice.get("id", "") + + if amount_due <= 0: + logger.info( + "handle_subscription_payment_failure: amount_due=%d for user %s;" + " nothing to deduct", + amount_due, + user.id, + ) + return + + credit_model = UserCredit() + try: + await credit_model._add_transaction( + user_id=user.id, + amount=-amount_due, + transaction_type=CreditTransactionType.SUBSCRIPTION, + fail_insufficient_credits=True, + # Use invoice_id as the idempotency key so that Stripe webhook retries + # (e.g. on a transient stripe.Invoice.pay failure) do not double-charge. + transaction_key=invoice_id or None, + metadata=SafeJson( + { + "stripe_customer_id": customer_id, + "stripe_subscription_id": sub_id, + "reason": "subscription_payment_failure_covered_by_balance", + } + ), + ) + # Balance covered the invoice. 
Pay the Stripe invoice so Stripe's dunning
+        # system stops retrying it — without this call Stripe would keep
+        # retrying and re-triggering this webhook every retry cycle (the
+        # transaction_key above makes the deduction itself idempotent, so the
+        # cost is retry churn and log noise rather than a double-charge).
+        if invoice_id:
+            try:
+                await run_in_threadpool(stripe.Invoice.pay, invoice_id)
+            except stripe.StripeError:
+                logger.warning(
+                    "handle_subscription_payment_failure: balance deducted for user"
+                    " %s but failed to mark invoice %s as paid; Stripe may retry",
+                    user.id,
+                    invoice_id,
+                )
+        logger.info(
+            "handle_subscription_payment_failure: deducted %d cents from balance"
+            " for user %s to cover Stripe invoice %s; sub %s intact, tier preserved",
+            amount_due,
+            user.id,
+            invoice_id,
+            sub_id,
+        )
+    except InsufficientBalanceError:
+        # Balance insufficient — cancel Stripe subscription first, then downgrade DB.
+        # Order matters: if we downgrade the DB first and the Stripe cancel fails, the
+        # user is permanently stuck on FREE while Stripe continues billing them.
+        # Cancelling Stripe first is safe: if the DB write then fails, the webhook
+        # customer.subscription.deleted will fire and correct the tier eventually.
+        logger.info(
+            "handle_subscription_payment_failure: insufficient balance for user %s;"
+            " cancelling Stripe sub %s then downgrading to FREE",
+            user.id,
+            sub_id,
+        )
+        try:
+            await _cancel_customer_subscriptions(customer_id)
+        except stripe.StripeError:
+            logger.warning(
+                "handle_subscription_payment_failure: failed to cancel Stripe sub %s"
+                " for user %s (customer %s); skipping tier downgrade to avoid"
+                " inconsistency — Stripe may continue retrying the invoice",
+                sub_id,
+                user.id,
+                customer_id,
+            )
+            return
+        await set_subscription_tier(user.id, SubscriptionTier.FREE)
+
+
 async def admin_get_user_history(
     page: int = 1,
     page_size: int = 20,
diff --git a/autogpt_platform/backend/backend/data/credit_subscription_test.py b/autogpt_platform/backend/backend/data/credit_subscription_test.py
index 34ba19b83c..a9634afcb4 100644
--- a/autogpt_platform/backend/backend/data/credit_subscription_test.py
+++ b/autogpt_platform/backend/backend/data/credit_subscription_test.py
@@ -5,12 +5,16 @@ Tests for Stripe-based subscription tier billing.
from unittest.mock import AsyncMock, MagicMock, patch import pytest +import stripe from prisma.enums import SubscriptionTier from prisma.models import User from backend.data.credit import ( cancel_stripe_subscription, create_subscription_checkout, + get_proration_credit_cents, + handle_subscription_payment_failure, + modify_stripe_subscription_for_tier, set_subscription_tier, sync_subscription_from_stripe, ) @@ -45,11 +49,18 @@ async def test_set_subscription_tier_downgrade(): await set_subscription_tier("user-1", SubscriptionTier.FREE) +def _make_user(user_id: str = "user-1", tier: SubscriptionTier = SubscriptionTier.FREE): + mock_user = MagicMock(spec=User) + mock_user.id = user_id + mock_user.subscriptionTier = tier + return mock_user + + @pytest.mark.asyncio async def test_sync_subscription_from_stripe_active(): - mock_user = MagicMock(spec=User) - mock_user.id = "user-1" + mock_user = _make_user() stripe_sub = { + "id": "sub_new", "customer": "cus_123", "status": "active", "items": {"data": [{"price": {"id": "price_pro_monthly"}}]}, @@ -62,6 +73,10 @@ async def test_sync_subscription_from_stripe_active(): return "price_biz_monthly" return None + empty_list = MagicMock() + empty_list.data = [] + empty_list.has_more = False + with ( patch( "backend.data.credit.User.prisma", @@ -71,6 +86,10 @@ async def test_sync_subscription_from_stripe_active(): "backend.data.credit.get_subscription_price_id", side_effect=mock_price_id, ), + patch( + "backend.data.credit.stripe.Subscription.list", + return_value=empty_list, + ), patch( "backend.data.credit.set_subscription_tier", new_callable=AsyncMock ) as mock_set, @@ -80,14 +99,59 @@ async def test_sync_subscription_from_stripe_active(): @pytest.mark.asyncio -async def test_sync_subscription_from_stripe_cancelled(): - mock_user = MagicMock(spec=User) - mock_user.id = "user-1" +async def test_sync_subscription_from_stripe_idempotent_no_write_if_unchanged(): + """Stripe retries webhooks; re-sending the same event must not re-write the DB.""" + mock_user = _make_user(tier=SubscriptionTier.PRO) stripe_sub = { + "id": "sub_new", "customer": "cus_123", - "status": "canceled", - "items": {"data": []}, + "status": "active", + "items": {"data": [{"price": {"id": "price_pro_monthly"}}]}, } + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + if tier == SubscriptionTier.PRO: + return "price_pro_monthly" + if tier == SubscriptionTier.BUSINESS: + return "price_biz_monthly" + return None + + empty_list = MagicMock() + empty_list.data = [] + empty_list.has_more = False + + with ( + patch( + "backend.data.credit.User.prisma", + return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)), + ), + patch( + "backend.data.credit.get_subscription_price_id", + side_effect=mock_price_id, + ), + patch( + "backend.data.credit.stripe.Subscription.list", + return_value=empty_list, + ), + patch( + "backend.data.credit.set_subscription_tier", new_callable=AsyncMock + ) as mock_set, + ): + await sync_subscription_from_stripe(stripe_sub) + mock_set.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_sync_subscription_from_stripe_enterprise_not_overwritten(): + """Webhook events must never overwrite an ENTERPRISE tier (admin-managed).""" + mock_user = _make_user(tier=SubscriptionTier.ENTERPRISE) + stripe_sub = { + "id": "sub_new", + "customer": "cus_123", + "status": "active", + "items": {"data": [{"price": {"id": "price_pro_monthly"}}]}, + } + with ( patch( "backend.data.credit.User.prisma", @@ -96,11 +160,131 @@ async def 
test_sync_subscription_from_stripe_cancelled(): patch( "backend.data.credit.set_subscription_tier", new_callable=AsyncMock ) as mock_set, + ): + await sync_subscription_from_stripe(stripe_sub) + mock_set.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_sync_subscription_from_stripe_cancelled(): + """When the only active sub is cancelled, the user is downgraded to FREE.""" + mock_user = _make_user(tier=SubscriptionTier.PRO) + stripe_sub = { + "id": "sub_old", + "customer": "cus_123", + "status": "canceled", + "items": {"data": []}, + } + empty_list = MagicMock() + empty_list.data = [] + empty_list.has_more = False + with ( + patch( + "backend.data.credit.User.prisma", + return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)), + ), + patch( + "backend.data.credit.stripe.Subscription.list", + return_value=empty_list, + ), + patch( + "backend.data.credit.set_subscription_tier", new_callable=AsyncMock + ) as mock_set, ): await sync_subscription_from_stripe(stripe_sub) mock_set.assert_awaited_once_with("user-1", SubscriptionTier.FREE) +@pytest.mark.asyncio +async def test_sync_subscription_from_stripe_cancelled_but_other_active_sub_exists(): + """Cancelling sub_old must NOT downgrade the user if sub_new is still active. + + This covers the race condition where `customer.subscription.deleted` for + the old sub arrives after `customer.subscription.created` for the new sub + was already processed. Unconditionally downgrading to FREE here would + immediately undo the user's upgrade. + """ + mock_user = _make_user(tier=SubscriptionTier.BUSINESS) + stripe_sub = { + "id": "sub_old", + "customer": "cus_123", + "status": "canceled", + "items": {"data": []}, + } + # Stripe still shows sub_new as active for this customer. + active_list = MagicMock() + active_list.data = [{"id": "sub_new"}] + active_list.has_more = False + empty_list = MagicMock() + empty_list.data = [] + empty_list.has_more = False + + def list_side_effect(*args, **kwargs): + if kwargs.get("status") == "active": + return active_list + return empty_list + + with ( + patch( + "backend.data.credit.User.prisma", + return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)), + ), + patch( + "backend.data.credit.stripe.Subscription.list", + side_effect=list_side_effect, + ), + patch( + "backend.data.credit.set_subscription_tier", new_callable=AsyncMock + ) as mock_set, + ): + await sync_subscription_from_stripe(stripe_sub) + # Must NOT write FREE — another active sub is still present. 
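+        # The correct tier is (or will be) written by sub_new's own webhook
+        # event; the only safe action for this cancel event is a no-op.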
+ mock_set.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_sync_subscription_from_stripe_trialing(): + """status='trialing' should map to the paid tier, same as 'active'.""" + mock_user = _make_user() + stripe_sub = { + "id": "sub_new", + "customer": "cus_123", + "status": "trialing", + "items": {"data": [{"price": {"id": "price_pro_monthly"}}]}, + } + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + if tier == SubscriptionTier.PRO: + return "price_pro_monthly" + if tier == SubscriptionTier.BUSINESS: + return "price_biz_monthly" + return None + + empty_list = MagicMock() + empty_list.data = [] + empty_list.has_more = False + + with ( + patch( + "backend.data.credit.User.prisma", + return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)), + ), + patch( + "backend.data.credit.get_subscription_price_id", + side_effect=mock_price_id, + ), + patch( + "backend.data.credit.stripe.Subscription.list", + return_value=empty_list, + ), + patch( + "backend.data.credit.set_subscription_tier", new_callable=AsyncMock + ) as mock_set, + ): + await sync_subscription_from_stripe(stripe_sub) + mock_set.assert_awaited_once_with("user-1", SubscriptionTier.PRO) + + @pytest.mark.asyncio async def test_sync_subscription_from_stripe_unknown_customer(): stripe_sub = { @@ -116,38 +300,98 @@ async def test_sync_subscription_from_stripe_unknown_customer(): await sync_subscription_from_stripe(stripe_sub) +def _make_user_with_stripe(stripe_customer_id: str | None = "cus_123") -> MagicMock: + """Return a mock model.User with the given stripe_customer_id.""" + mock_user = MagicMock() + mock_user.stripe_customer_id = stripe_customer_id + return mock_user + + @pytest.mark.asyncio async def test_cancel_stripe_subscription_cancels_active(): - mock_sub = {"id": "sub_abc123"} mock_subscriptions = MagicMock() - mock_subscriptions.auto_paging_iter.return_value = iter([mock_sub]) + mock_subscriptions.data = [{"id": "sub_abc123"}] + mock_subscriptions.has_more = False with ( patch( - "backend.data.credit.get_stripe_customer_id", + "backend.data.credit.get_user_by_id", new_callable=AsyncMock, - return_value="cus_123", + return_value=_make_user_with_stripe("cus_123"), ), patch( "backend.data.credit.stripe.Subscription.list", return_value=mock_subscriptions, ), - patch("backend.data.credit.stripe.Subscription.cancel") as mock_cancel, + patch("backend.data.credit.stripe.Subscription.modify") as mock_modify, ): await cancel_stripe_subscription("user-1") - mock_cancel.assert_called_once_with("sub_abc123") + mock_modify.assert_called_once_with("sub_abc123", cancel_at_period_end=True) + + +@pytest.mark.asyncio +async def test_cancel_stripe_subscription_no_customer_id_returns_false(): + """Users with no stripe_customer_id return False without creating a Stripe customer.""" + result = False + with patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=_make_user_with_stripe(stripe_customer_id=None), + ): + result = await cancel_stripe_subscription("user-1") + assert result is False + + +@pytest.mark.asyncio +async def test_cancel_stripe_subscription_multi_partial_failure(): + """First modify raises → error propagates and subsequent subs are not scheduled.""" + mock_subscriptions = MagicMock() + mock_subscriptions.data = [{"id": "sub_first"}, {"id": "sub_second"}] + mock_subscriptions.has_more = False + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=_make_user_with_stripe("cus_123"), + ), + patch( + 
"backend.data.credit.stripe.Subscription.list", + return_value=mock_subscriptions, + ), + patch( + "backend.data.credit.stripe.Subscription.modify", + side_effect=stripe.StripeError("first modify failed"), + ) as mock_modify, + patch( + "backend.data.credit.set_subscription_tier", + new_callable=AsyncMock, + ) as mock_set_tier, + ): + with pytest.raises(stripe.StripeError): + await cancel_stripe_subscription("user-1") + # Only the first modify should have been attempted. + # _cancel_customer_subscriptions has no per-cancel try/except, so the + # StripeError propagates immediately, aborting the loop before sub_second + # is attempted. This is intentional fail-fast behaviour — the caller + # (cancel_stripe_subscription) re-raises and the API handler returns 502. + mock_modify.assert_called_once_with("sub_first", cancel_at_period_end=True) + # DB tier must NOT be updated on the error path — the caller raises + # before reaching set_subscription_tier. + mock_set_tier.assert_not_called() @pytest.mark.asyncio async def test_cancel_stripe_subscription_no_active(): mock_subscriptions = MagicMock() - mock_subscriptions.auto_paging_iter.return_value = iter([]) + mock_subscriptions.data = [] + mock_subscriptions.has_more = False with ( patch( - "backend.data.credit.get_stripe_customer_id", + "backend.data.credit.get_user_by_id", new_callable=AsyncMock, - return_value="cus_123", + return_value=_make_user_with_stripe("cus_123"), ), patch( "backend.data.credit.stripe.Subscription.list", @@ -159,6 +403,139 @@ async def test_cancel_stripe_subscription_no_active(): mock_cancel.assert_not_called() +@pytest.mark.asyncio +async def test_cancel_stripe_subscription_raises_on_list_failure(): + """stripe.Subscription.list() failure propagates so DB tier is not updated.""" + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=_make_user_with_stripe("cus_123"), + ), + patch( + "backend.data.credit.stripe.Subscription.list", + side_effect=stripe.StripeError("network error"), + ), + ): + with pytest.raises(stripe.StripeError): + await cancel_stripe_subscription("user-1") + + +@pytest.mark.asyncio +async def test_cancel_stripe_subscription_cancels_trialing(): + """Trialing subs must also be scheduled for cancellation, else users get billed after trial end.""" + active_subs = MagicMock() + active_subs.data = [] + active_subs.has_more = False + trialing_subs = MagicMock() + trialing_subs.data = [{"id": "sub_trial_123"}] + trialing_subs.has_more = False + + def list_side_effect(*args, **kwargs): + return trialing_subs if kwargs.get("status") == "trialing" else active_subs + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=_make_user_with_stripe("cus_123"), + ), + patch( + "backend.data.credit.stripe.Subscription.list", + side_effect=list_side_effect, + ), + patch("backend.data.credit.stripe.Subscription.modify") as mock_modify, + ): + await cancel_stripe_subscription("user-1") + mock_modify.assert_called_once_with("sub_trial_123", cancel_at_period_end=True) + + +@pytest.mark.asyncio +async def test_cancel_stripe_subscription_cancels_active_and_trialing(): + """Both active AND trialing subs present → both get scheduled for cancellation, no duplicates.""" + active_subs = MagicMock() + active_subs.data = [{"id": "sub_active_1"}] + active_subs.has_more = False + trialing_subs = MagicMock() + trialing_subs.data = [{"id": "sub_trial_2"}] + trialing_subs.has_more = False + + def list_side_effect(*args, **kwargs): + return 
trialing_subs if kwargs.get("status") == "trialing" else active_subs + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=_make_user_with_stripe("cus_123"), + ), + patch( + "backend.data.credit.stripe.Subscription.list", + side_effect=list_side_effect, + ), + patch("backend.data.credit.stripe.Subscription.modify") as mock_modify, + ): + await cancel_stripe_subscription("user-1") + modified_ids = {call.args[0] for call in mock_modify.call_args_list} + assert modified_ids == {"sub_active_1", "sub_trial_2"} + + +@pytest.mark.asyncio +async def test_get_proration_credit_cents_no_stripe_customer_returns_zero(): + """Admin-granted tier users without stripe_customer_id get 0 without creating a customer.""" + with patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=_make_user_with_stripe(stripe_customer_id=None), + ) as mock_user: + result = await get_proration_credit_cents("user-1", monthly_cost_cents=2000) + assert result == 0 + mock_user.assert_awaited_once_with("user-1") + + +@pytest.mark.asyncio +async def test_get_proration_credit_cents_zero_cost_returns_zero(): + """FREE tier users (cost=0) return 0 without calling get_user_by_id.""" + with patch( + "backend.data.credit.get_user_by_id", new_callable=AsyncMock + ) as mock_get_user: + result = await get_proration_credit_cents("user-1", monthly_cost_cents=0) + assert result == 0 + mock_get_user.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_get_proration_credit_cents_with_active_sub(): + """User with active sub returns prorated credit based on remaining billing period.""" + import time + + now = int(time.time()) + period_start = now - 15 * 24 * 3600 # 15 days ago + period_end = now + 15 * 24 * 3600 # 15 days ahead + mock_sub = { + "id": "sub_abc", + "current_period_start": period_start, + "current_period_end": period_end, + } + mock_subs = MagicMock() + mock_subs.data = [mock_sub] + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=_make_user_with_stripe("cus_123"), + ), + patch( + "backend.data.credit.stripe.Subscription.list", + return_value=mock_subs, + ), + ): + result = await get_proration_credit_cents("user-1", monthly_cost_cents=2000) + assert result > 0 + assert result < 2000 + + @pytest.mark.asyncio async def test_create_subscription_checkout_returns_url(): mock_session = MagicMock() @@ -174,7 +551,10 @@ async def test_create_subscription_checkout_returns_url(): new_callable=AsyncMock, return_value="cus_123", ), - patch("stripe.checkout.Session.create", return_value=mock_session), + patch( + "backend.data.credit.stripe.checkout.Session.create", + return_value=mock_session, + ), ): url = await create_subscription_checkout( user_id="user-1", @@ -202,10 +582,31 @@ async def test_create_subscription_checkout_no_price_raises(): @pytest.mark.asyncio -async def test_sync_subscription_from_stripe_unknown_price_defaults_to_free(): - """Unknown price_id should default to FREE instead of returning early.""" - mock_user = MagicMock(spec=User) - mock_user.id = "user-1" +async def test_sync_subscription_from_stripe_missing_customer_key_returns_early(): + """A webhook payload missing 'customer' must not raise KeyError — returns early with a warning.""" + stripe_sub = { + # Omit "customer" entirely — simulates a valid HMAC but malformed payload + "status": "active", + "id": "sub_xyz", + "items": {"data": [{"price": {"id": "price_pro"}}]}, + } + + with ( + patch("backend.data.credit.User.prisma") as 
mock_prisma, + patch( + "backend.data.credit.set_subscription_tier", new_callable=AsyncMock + ) as mock_set, + ): + # Should return early without querying the DB or writing a tier + await sync_subscription_from_stripe(stripe_sub) + mock_prisma.assert_not_called() + mock_set.assert_not_called() + + +@pytest.mark.asyncio +async def test_sync_subscription_from_stripe_unknown_price_id_preserves_current_tier(): + """Unknown price_id should preserve the current tier, not default to FREE (no DB write).""" + mock_user = _make_user(tier=SubscriptionTier.PRO) stripe_sub = { "customer": "cus_123", "status": "active", @@ -234,10 +635,9 @@ async def test_sync_subscription_from_stripe_unknown_price_defaults_to_free(): @pytest.mark.asyncio -async def test_sync_subscription_from_stripe_none_ld_price_defaults_to_free(): - """When LD returns None for price IDs, active subscription should default to FREE.""" - mock_user = MagicMock(spec=User) - mock_user.id = "user-1" +async def test_sync_subscription_from_stripe_unconfigured_ld_price_preserves_current_tier(): + """When LD flags are unconfigured (None price IDs), the current tier should be preserved, not defaulted to FREE.""" + mock_user = _make_user(tier=SubscriptionTier.PRO) stripe_sub = { "customer": "cus_123", "status": "active", @@ -266,9 +666,9 @@ async def test_sync_subscription_from_stripe_none_ld_price_defaults_to_free(): @pytest.mark.asyncio async def test_sync_subscription_from_stripe_business_tier(): """BUSINESS price_id should map to BUSINESS tier.""" - mock_user = MagicMock(spec=User) - mock_user.id = "user-1" + mock_user = _make_user() stripe_sub = { + "id": "sub_new", "customer": "cus_123", "status": "active", "items": {"data": [{"price": {"id": "price_biz_monthly"}}]}, @@ -281,6 +681,10 @@ async def test_sync_subscription_from_stripe_business_tier(): return "price_biz_monthly" return None + empty_list = MagicMock() + empty_list.data = [] + empty_list.has_more = False + with ( patch( "backend.data.credit.User.prisma", @@ -290,6 +694,10 @@ async def test_sync_subscription_from_stripe_business_tier(): "backend.data.credit.get_subscription_price_id", side_effect=mock_price_id, ), + patch( + "backend.data.credit.stripe.Subscription.list", + return_value=empty_list, + ), patch( "backend.data.credit.set_subscription_tier", new_callable=AsyncMock ) as mock_set, @@ -298,10 +706,115 @@ async def test_sync_subscription_from_stripe_business_tier(): mock_set.assert_awaited_once_with("user-1", SubscriptionTier.BUSINESS) +@pytest.mark.asyncio +async def test_sync_subscription_from_stripe_cancels_stale_subs(): + """When a new subscription becomes active, older active subs are cancelled. + + Covers the paid-to-paid upgrade case (e.g. PRO → BUSINESS) where Stripe + Checkout creates a new subscription without touching the previous one, + leaving the customer double-billed. 
+ """ + mock_user = _make_user(tier=SubscriptionTier.PRO) + stripe_sub = { + "id": "sub_new", + "customer": "cus_123", + "status": "active", + "items": {"data": [{"price": {"id": "price_biz_monthly"}}]}, + } + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + if tier == SubscriptionTier.PRO: + return "price_pro_monthly" + if tier == SubscriptionTier.BUSINESS: + return "price_biz_monthly" + return None + + existing = MagicMock() + existing.data = [{"id": "sub_old"}, {"id": "sub_new"}] + existing.has_more = False + + with ( + patch( + "backend.data.credit.User.prisma", + return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)), + ), + patch( + "backend.data.credit.get_subscription_price_id", + side_effect=mock_price_id, + ), + patch( + "backend.data.credit.stripe.Subscription.list", + return_value=existing, + ), + patch( + "backend.data.credit.stripe.Subscription.cancel", + ) as mock_cancel, + patch( + "backend.data.credit.set_subscription_tier", new_callable=AsyncMock + ) as mock_set, + ): + await sync_subscription_from_stripe(stripe_sub) + mock_set.assert_awaited_once_with("user-1", SubscriptionTier.BUSINESS) + # Only the stale sub should be cancelled — never the new one. + mock_cancel.assert_called_once_with("sub_old") + + +@pytest.mark.asyncio +async def test_sync_subscription_from_stripe_stale_cancel_errors_swallowed(): + """Errors cancelling stale subs must not block DB tier update for new sub.""" + import stripe as stripe_mod + + mock_user = _make_user(tier=SubscriptionTier.BUSINESS) + stripe_sub = { + "id": "sub_new", + "customer": "cus_123", + "status": "active", + "items": {"data": [{"price": {"id": "price_pro_monthly"}}]}, + } + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + if tier == SubscriptionTier.PRO: + return "price_pro_monthly" + if tier == SubscriptionTier.BUSINESS: + return "price_biz_monthly" + return None + + existing = MagicMock() + existing.data = [{"id": "sub_old"}] + existing.has_more = False + + with ( + patch( + "backend.data.credit.User.prisma", + return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)), + ), + patch( + "backend.data.credit.get_subscription_price_id", + side_effect=mock_price_id, + ), + patch( + "backend.data.credit.stripe.Subscription.list", + return_value=existing, + ), + patch( + "backend.data.credit.stripe.Subscription.cancel", + side_effect=stripe_mod.StripeError("cancel failed"), + ), + patch( + "backend.data.credit.set_subscription_tier", new_callable=AsyncMock + ) as mock_set, + ): + # Must not raise — tier update proceeds even if cleanup cancel fails. + await sync_subscription_from_stripe(stripe_sub) + mock_set.assert_awaited_once_with("user-1", SubscriptionTier.PRO) + + @pytest.mark.asyncio async def test_get_subscription_price_id_pro(): from backend.data.credit import get_subscription_price_id + # Clear cached state from other tests to ensure a fresh LD flag lookup. 
+ get_subscription_price_id.cache_clear() # type: ignore[attr-defined] with patch( "backend.data.credit.get_feature_flag_value", new_callable=AsyncMock, @@ -309,12 +822,14 @@ async def test_get_subscription_price_id_pro(): ): price_id = await get_subscription_price_id(SubscriptionTier.PRO) assert price_id == "price_pro_monthly" + get_subscription_price_id.cache_clear() # type: ignore[attr-defined] @pytest.mark.asyncio async def test_get_subscription_price_id_free_returns_none(): from backend.data.credit import get_subscription_price_id + # FREE tier bypasses the LD flag lookup entirely (returns None before fetch). price_id = await get_subscription_price_id(SubscriptionTier.FREE) assert price_id is None @@ -323,6 +838,7 @@ async def test_get_subscription_price_id_free_returns_none(): async def test_get_subscription_price_id_empty_flag_returns_none(): from backend.data.credit import get_subscription_price_id + get_subscription_price_id.cache_clear() # type: ignore[attr-defined] with patch( "backend.data.credit.get_feature_flag_value", new_callable=AsyncMock, @@ -330,31 +846,369 @@ async def test_get_subscription_price_id_empty_flag_returns_none(): ): price_id = await get_subscription_price_id(SubscriptionTier.BUSINESS) assert price_id is None + get_subscription_price_id.cache_clear() # type: ignore[attr-defined] @pytest.mark.asyncio -async def test_cancel_stripe_subscription_handles_stripe_error(): - """Stripe errors during cancellation should be logged, not raised.""" +async def test_get_subscription_price_id_none_not_cached(): + """None returns from transient LD failures are not cached (cache_none=False). + + Without cache_none=False a single LD hiccup would block upgrades for the + full 60-second TTL window because the ``None`` sentinel would be served from + cache on every subsequent call. 
+ """ + from backend.data.credit import get_subscription_price_id + + get_subscription_price_id.cache_clear() # type: ignore[attr-defined] + mock_ld = AsyncMock(side_effect=["", "price_pro_monthly"]) + with patch("backend.data.credit.get_feature_flag_value", mock_ld): + # First call: LD returns empty string → None (transient failure) + first = await get_subscription_price_id(SubscriptionTier.PRO) + assert first is None + # Second call: LD returns the real price ID — must NOT be blocked by cached None + second = await get_subscription_price_id(SubscriptionTier.PRO) + assert second == "price_pro_monthly" + assert mock_ld.call_count == 2 # both calls hit LD (None was not cached) + get_subscription_price_id.cache_clear() # type: ignore[attr-defined] + + +@pytest.mark.asyncio +async def test_cancel_stripe_subscription_raises_on_cancel_error(): + """Stripe errors during period-end scheduling are re-raised so the DB tier is not updated.""" import stripe as stripe_mod - mock_sub = {"id": "sub_abc123"} mock_subscriptions = MagicMock() - mock_subscriptions.auto_paging_iter.return_value = iter([mock_sub]) + mock_subscriptions.data = [{"id": "sub_abc123"}] + mock_subscriptions.has_more = False with ( patch( - "backend.data.credit.get_stripe_customer_id", + "backend.data.credit.get_user_by_id", new_callable=AsyncMock, - return_value="cus_123", + return_value=_make_user_with_stripe("cus_123"), ), patch( "backend.data.credit.stripe.Subscription.list", return_value=mock_subscriptions, ), patch( - "backend.data.credit.stripe.Subscription.cancel", + "backend.data.credit.stripe.Subscription.modify", side_effect=stripe_mod.StripeError("network error"), ), ): - # Should not raise — errors are logged as warnings - await cancel_stripe_subscription("user-1") + with pytest.raises(stripe_mod.StripeError): + await cancel_stripe_subscription("user-1") + + +@pytest.mark.asyncio +async def test_sync_subscription_from_stripe_metadata_user_id_matches(): + """metadata.user_id matching the DB user is accepted and the tier is updated normally.""" + mock_user = _make_user(user_id="user-1", tier=SubscriptionTier.FREE) + stripe_sub = { + "id": "sub_new", + "customer": "cus_123", + "status": "active", + "metadata": {"user_id": "user-1"}, + "items": {"data": [{"price": {"id": "price_pro_monthly"}}]}, + } + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + return "price_pro_monthly" if tier == SubscriptionTier.PRO else None + + empty_list = MagicMock() + empty_list.data = [] + empty_list.has_more = False + + with ( + patch( + "backend.data.credit.User.prisma", + return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)), + ), + patch( + "backend.data.credit.get_subscription_price_id", + side_effect=mock_price_id, + ), + patch( + "backend.data.credit.stripe.Subscription.list", + return_value=empty_list, + ), + patch( + "backend.data.credit.set_subscription_tier", new_callable=AsyncMock + ) as mock_set, + ): + await sync_subscription_from_stripe(stripe_sub) + mock_set.assert_awaited_once_with("user-1", SubscriptionTier.PRO) + + +@pytest.mark.asyncio +async def test_sync_subscription_from_stripe_metadata_user_id_mismatch_blocked(): + """metadata.user_id mismatching the DB user must block the tier update. + + A customer↔user mapping inconsistency (e.g. a customer ID reassigned or + a corrupted DB row) must never silently update the wrong user's tier. 
+ """ + mock_user = _make_user(user_id="user-1", tier=SubscriptionTier.FREE) + stripe_sub = { + "id": "sub_new", + "customer": "cus_123", + "status": "active", + "metadata": {"user_id": "user-different"}, + "items": {"data": [{"price": {"id": "price_pro_monthly"}}]}, + } + + with ( + patch( + "backend.data.credit.User.prisma", + return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)), + ), + patch( + "backend.data.credit.set_subscription_tier", new_callable=AsyncMock + ) as mock_set, + ): + await sync_subscription_from_stripe(stripe_sub) + # Mismatch → must not update any tier + mock_set.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_sync_subscription_from_stripe_no_metadata_user_id_skips_check(): + """Absence of metadata.user_id (e.g. subs created outside Checkout) skips the cross-check.""" + mock_user = _make_user(user_id="user-1", tier=SubscriptionTier.FREE) + stripe_sub = { + "id": "sub_new", + "customer": "cus_123", + "status": "active", + "items": {"data": [{"price": {"id": "price_pro_monthly"}}]}, + # No "metadata" key at all + } + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + return "price_pro_monthly" if tier == SubscriptionTier.PRO else None + + empty_list = MagicMock() + empty_list.data = [] + empty_list.has_more = False + + with ( + patch( + "backend.data.credit.User.prisma", + return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)), + ), + patch( + "backend.data.credit.get_subscription_price_id", + side_effect=mock_price_id, + ), + patch( + "backend.data.credit.stripe.Subscription.list", + return_value=empty_list, + ), + patch( + "backend.data.credit.set_subscription_tier", new_callable=AsyncMock + ) as mock_set, + ): + await sync_subscription_from_stripe(stripe_sub) + # No metadata → cross-check skipped → tier updated normally + mock_set.assert_awaited_once_with("user-1", SubscriptionTier.PRO) + + +@pytest.mark.asyncio +async def test_handle_subscription_payment_failure_balance_covers_pays_invoice(): + """When balance covers the invoice, Stripe Invoice.pay is called to stop retries.""" + mock_user = _make_user(user_id="user-1", tier=SubscriptionTier.PRO) + invoice = { + "id": "in_abc123", + "customer": "cus_123", + "subscription": "sub_abc123", + "amount_due": 2000, + } + + with ( + patch( + "backend.data.credit.User.prisma", + return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)), + ), + patch( + "backend.data.credit.UserCredit._add_transaction", + new_callable=AsyncMock, + ), + patch("backend.data.credit.stripe.Invoice.pay") as mock_pay, + ): + await handle_subscription_payment_failure(invoice) + mock_pay.assert_called_once_with("in_abc123") + + +@pytest.mark.asyncio +async def test_handle_subscription_payment_failure_invoice_pay_error_does_not_raise(): + """Failure to mark the invoice as paid is logged but does not propagate.""" + import stripe as stripe_mod + + mock_user = _make_user(user_id="user-1", tier=SubscriptionTier.PRO) + invoice = { + "id": "in_abc123", + "customer": "cus_123", + "subscription": "sub_abc123", + "amount_due": 2000, + } + + with ( + patch( + "backend.data.credit.User.prisma", + return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)), + ), + patch( + "backend.data.credit.UserCredit._add_transaction", + new_callable=AsyncMock, + ), + patch( + "backend.data.credit.stripe.Invoice.pay", + side_effect=stripe_mod.StripeError("network error"), + ), + ): + # Must not raise — the pay failure is only logged as a warning + await 
handle_subscription_payment_failure(invoice) + + +@pytest.mark.asyncio +async def test_handle_subscription_payment_failure_passes_invoice_id_as_transaction_key(): + """invoice_id is used as the idempotency key to prevent double-charging on webhook retries.""" + mock_user = _make_user(user_id="user-1", tier=SubscriptionTier.PRO) + invoice = { + "id": "in_idempotency_test", + "customer": "cus_123", + "subscription": "sub_abc123", + "amount_due": 2000, + } + + with ( + patch( + "backend.data.credit.User.prisma", + return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)), + ), + patch( + "backend.data.credit.UserCredit._add_transaction", + new_callable=AsyncMock, + ) as mock_add_tx, + patch("backend.data.credit.stripe.Invoice.pay"), + ): + await handle_subscription_payment_failure(invoice) + mock_add_tx.assert_called_once() + _, kwargs = mock_add_tx.call_args + assert kwargs.get("transaction_key") == "in_idempotency_test" + + +@pytest.mark.asyncio +async def test_modify_stripe_subscription_for_tier_modifies_existing_sub(): + """modify_stripe_subscription_for_tier calls Subscription.modify and returns True.""" + mock_sub = { + "id": "sub_abc", + "items": {"data": [{"id": "si_abc"}]}, + } + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock(spec=User) + mock_user.stripe_customer_id = "cus_abc" + + with ( + patch( + "backend.data.credit.get_subscription_price_id", + new_callable=AsyncMock, + return_value="price_pro_monthly", + ), + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list", + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.Subscription.modify", + ) as mock_modify, + ): + result = await modify_stripe_subscription_for_tier( + "user-1", SubscriptionTier.PRO + ) + + assert result is True + mock_modify.assert_called_once_with( + "sub_abc", + items=[{"id": "si_abc", "price": "price_pro_monthly"}], + proration_behavior="create_prorations", + ) + + +@pytest.mark.asyncio +async def test_modify_stripe_subscription_for_tier_returns_false_when_no_customer_id(): + """modify_stripe_subscription_for_tier returns False when user has no Stripe customer ID. + + Admin-granted paid tiers have no Stripe customer record. Calling + get_stripe_customer_id would create an orphaned customer if a subsequent API call + fails, so the function returns False early and the API layer falls back to Checkout. 
+ """ + mock_user = MagicMock(spec=User) + mock_user.stripe_customer_id = None + + with ( + patch( + "backend.data.credit.get_subscription_price_id", + new_callable=AsyncMock, + return_value="price_pro_monthly", + ), + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + ): + result = await modify_stripe_subscription_for_tier( + "user-1", SubscriptionTier.PRO + ) + + assert result is False + + +@pytest.mark.asyncio +async def test_modify_stripe_subscription_for_tier_returns_false_when_no_sub(): + """modify_stripe_subscription_for_tier returns False when no active subscription exists.""" + mock_list = MagicMock() + mock_list.data = [] + + mock_user = MagicMock(spec=User) + mock_user.stripe_customer_id = "cus_abc" + + with ( + patch( + "backend.data.credit.get_subscription_price_id", + new_callable=AsyncMock, + return_value="price_pro_monthly", + ), + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list", + return_value=mock_list, + ), + ): + result = await modify_stripe_subscription_for_tier( + "user-1", SubscriptionTier.PRO + ) + + assert result is False + + +@pytest.mark.asyncio +async def test_modify_stripe_subscription_for_tier_raises_on_missing_price_id(): + """modify_stripe_subscription_for_tier raises ValueError when no price ID is configured.""" + with patch( + "backend.data.credit.get_subscription_price_id", + new_callable=AsyncMock, + return_value=None, + ): + with pytest.raises(ValueError, match="No Stripe price ID configured"): + await modify_stripe_subscription_for_tier("user-1", SubscriptionTier.PRO) diff --git a/autogpt_platform/backend/backend/util/cache.py b/autogpt_platform/backend/backend/util/cache.py index d813a42211..8f55d49fdc 100644 --- a/autogpt_platform/backend/backend/util/cache.py +++ b/autogpt_platform/backend/backend/util/cache.py @@ -73,6 +73,31 @@ def _get_redis() -> Redis: return r +class _MissingType: + """Singleton sentinel type — distinct from ``None`` (a valid cached value). + + Using a dedicated class (instead of ``Any = object()``) lets mypy prove + that comparisons ``result is _MISSING`` narrow the type correctly and + prevents accidental use of the sentinel where a real value is expected. + """ + + _instance: "_MissingType | None" = None + + def __new__(cls) -> "_MissingType": + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __repr__(self) -> str: + return "" + + +# Sentinel returned by ``_get_from_memory`` / ``_get_from_redis`` to mean +# "no entry exists" — distinct from a cached ``None`` value, which is a +# valid result for callers that opt into caching it. +_MISSING = _MissingType() + + @dataclass class CachedValue: """Wrapper for cached values with timestamp to avoid tuple ambiguity.""" @@ -160,6 +185,7 @@ def cached( ttl_seconds: int, shared_cache: bool = False, refresh_ttl_on_get: bool = False, + cache_none: bool = True, ) -> Callable[[Callable[P, R]], CachedFunction[P, R]]: """ Thundering herd safe cache decorator for both sync and async functions. @@ -172,6 +198,10 @@ def cached( ttl_seconds: Time to live in seconds. Required - entries must expire. shared_cache: If True, use Redis for cross-process caching refresh_ttl_on_get: If True, refresh TTL when cache entry is accessed (LRU behavior) + cache_none: If True (default) ``None`` is cached like any other value. 
+ Set to ``False`` for functions that return ``None`` to signal a + transient error and should be re-tried on the next call without + poisoning the cache (e.g. external API calls that may fail). Returns: Decorated function with caching capabilities @@ -184,6 +214,12 @@ def cached( @cached(ttl_seconds=600, shared_cache=True, refresh_ttl_on_get=True) async def expensive_async_operation(param: str) -> dict: return {"result": param} + + @cached(ttl_seconds=300, cache_none=False) + async def fetch_external(id: str) -> dict | None: + # Returns None on transient error — won't be stored, + # next call retries instead of returning the stale None. + ... """ def decorator(target_func: Callable[P, R]) -> CachedFunction[P, R]: @@ -191,9 +227,14 @@ def cached( cache_storage: dict[tuple, CachedValue] = {} _event_loop_locks: dict[Any, asyncio.Lock] = {} - def _get_from_redis(redis_key: str) -> Any | None: + def _get_from_redis(redis_key: str) -> Any: """Get value from Redis, optionally refreshing TTL. + Returns the cached value (which may be ``None``) on a hit, or the + module-level ``_MISSING`` sentinel on a miss / corrupt entry. + Callers must compare with ``is _MISSING`` so cached ``None`` values + are not mistaken for misses. + Values are expected to carry an HMAC-SHA256 prefix for integrity verification. Unsigned (legacy) or tampered entries are silently discarded and treated as cache misses, so the caller recomputes and @@ -213,11 +254,11 @@ def cached( f"for {func_name}, discarding entry: " "possible tampering or legacy unsigned value" ) - return None + return _MISSING return pickle.loads(payload) except Exception as e: logger.error(f"Redis error during cache check for {func_name}: {e}") - return None + return _MISSING def _set_to_redis(redis_key: str, value: Any) -> None: """Set HMAC-signed pickled value in Redis with TTL.""" @@ -227,8 +268,13 @@ def cached( except Exception as e: logger.error(f"Redis error storing cache for {func_name}: {e}") - def _get_from_memory(key: tuple) -> Any | None: - """Get value from in-memory cache, checking TTL.""" + def _get_from_memory(key: tuple) -> Any: + """Get value from in-memory cache, checking TTL. + + Returns the cached value (which may be ``None``) on a hit, or the + ``_MISSING`` sentinel on a miss / TTL expiry. See + ``_get_from_redis`` for the rationale. 
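+
+            Caller-side contract (mirrors the wrapper code below)::
+
+                result = _get_from_memory(key)
+                if result is not _MISSING:
+                    return result  # a cached ``None`` is a legitimate hit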
+ """ if key in cache_storage: cached_data = cache_storage[key] if time.time() - cached_data.timestamp < ttl_seconds: @@ -236,7 +282,7 @@ def cached( f"Cache hit for {func_name} args: {key[0]} kwargs: {key[1]}" ) return cached_data.result - return None + return _MISSING def _set_to_memory(key: tuple, value: Any) -> None: """Set value in in-memory cache with timestamp.""" @@ -270,11 +316,11 @@ def cached( # Fast path: check cache without lock if shared_cache: result = _get_from_redis(redis_key) - if result is not None: + if result is not _MISSING: return result else: result = _get_from_memory(key) - if result is not None: + if result is not _MISSING: return result # Slow path: acquire lock for cache miss/expiry @@ -282,22 +328,24 @@ def cached( # Double-check: another coroutine might have populated cache if shared_cache: result = _get_from_redis(redis_key) - if result is not None: + if result is not _MISSING: return result else: result = _get_from_memory(key) - if result is not None: + if result is not _MISSING: return result # Cache miss - execute function logger.debug(f"Cache miss for {func_name}") result = await target_func(*args, **kwargs) - # Store result - if shared_cache: - _set_to_redis(redis_key, result) - else: - _set_to_memory(key, result) + # Store result (skip ``None`` if the caller opted out of + # caching it — used for transient-error sentinels). + if cache_none or result is not None: + if shared_cache: + _set_to_redis(redis_key, result) + else: + _set_to_memory(key, result) return result @@ -315,11 +363,11 @@ def cached( # Fast path: check cache without lock if shared_cache: result = _get_from_redis(redis_key) - if result is not None: + if result is not _MISSING: return result else: result = _get_from_memory(key) - if result is not None: + if result is not _MISSING: return result # Slow path: acquire lock for cache miss/expiry @@ -327,22 +375,24 @@ def cached( # Double-check: another thread might have populated cache if shared_cache: result = _get_from_redis(redis_key) - if result is not None: + if result is not _MISSING: return result else: result = _get_from_memory(key) - if result is not None: + if result is not _MISSING: return result # Cache miss - execute function logger.debug(f"Cache miss for {func_name}") result = target_func(*args, **kwargs) - # Store result - if shared_cache: - _set_to_redis(redis_key, result) - else: - _set_to_memory(key, result) + # Store result (skip ``None`` if the caller opted out of + # caching it — used for transient-error sentinels). + if cache_none or result is not None: + if shared_cache: + _set_to_redis(redis_key, result) + else: + _set_to_memory(key, result) return result diff --git a/autogpt_platform/backend/backend/util/cache_test.py b/autogpt_platform/backend/backend/util/cache_test.py index ee752152ff..0ee41f948f 100644 --- a/autogpt_platform/backend/backend/util/cache_test.py +++ b/autogpt_platform/backend/backend/util/cache_test.py @@ -1223,3 +1223,123 @@ class TestCacheHMAC: assert call_count == 2 legacy_test_fn.cache_clear() + + +class TestCacheNoneHandling: + """Tests for the ``cache_none`` parameter on the @cached decorator. + + Sentry bug PRRT_kwDOJKSTjM56RTEu (HIGH): the cache previously could not + distinguish "no entry" from "entry is None", so any function returning + ``None`` was effectively re-executed on every call. The fix is a + sentinel-based check inside the wrappers, plus an opt-out + ``cache_none=False`` flag for callers that *want* errors to retry. 
+ """ + + @pytest.mark.asyncio + async def test_async_none_is_cached_by_default(self): + """With ``cache_none=True`` (default), cached ``None`` is returned + from the cache instead of triggering re-execution.""" + call_count = 0 + + @cached(ttl_seconds=300) + async def maybe_none(x: int) -> int | None: + nonlocal call_count + call_count += 1 + return None + + assert await maybe_none(1) is None + assert call_count == 1 + + # Second call should hit the cache, not re-execute. + assert await maybe_none(1) is None + assert call_count == 1 + + # Different argument is a different cache key — re-executes. + assert await maybe_none(2) is None + assert call_count == 2 + + def test_sync_none_is_cached_by_default(self): + call_count = 0 + + @cached(ttl_seconds=300) + def maybe_none(x: int) -> int | None: + nonlocal call_count + call_count += 1 + return None + + assert maybe_none(1) is None + assert maybe_none(1) is None + assert call_count == 1 + + @pytest.mark.asyncio + async def test_async_cache_none_false_skips_storing_none(self): + """``cache_none=False`` skips storing ``None`` so transient errors + are retried on the next call instead of poisoning the cache.""" + call_count = 0 + results: list[int | None] = [None, None, 42] + + @cached(ttl_seconds=300, cache_none=False) + async def maybe_none(x: int) -> int | None: + nonlocal call_count + result = results[call_count] + call_count += 1 + return result + + # First call: returns None, NOT stored. + assert await maybe_none(1) is None + assert call_count == 1 + + # Second call with same key: re-executes (None wasn't cached). + assert await maybe_none(1) is None + assert call_count == 2 + + # Third call: returns 42, this time it IS stored. + assert await maybe_none(1) == 42 + assert call_count == 3 + + # Fourth call: cache hit on the stored 42. + assert await maybe_none(1) == 42 + assert call_count == 3 + + def test_sync_cache_none_false_skips_storing_none(self): + call_count = 0 + results: list[int | None] = [None, 99] + + @cached(ttl_seconds=300, cache_none=False) + def maybe_none(x: int) -> int | None: + nonlocal call_count + result = results[call_count] + call_count += 1 + return result + + assert maybe_none(1) is None + assert call_count == 1 + + # None was not stored — re-executes. + assert maybe_none(1) == 99 + assert call_count == 2 + + # 99 IS stored — no re-execution. 
+ assert maybe_none(1) == 99 + assert call_count == 2 + + @pytest.mark.asyncio + async def test_async_shared_cache_none_is_cached_by_default(self): + """Shared (Redis) cache also properly returns cached ``None`` values.""" + call_count = 0 + + @cached(ttl_seconds=30, shared_cache=True) + async def maybe_none_redis(x: int) -> int | None: + nonlocal call_count + call_count += 1 + return None + + maybe_none_redis.cache_clear() + + assert await maybe_none_redis(1) is None + assert call_count == 1 + + assert await maybe_none_redis(1) is None + assert call_count == 1 + + maybe_none_redis.cache_clear() diff --git a/autogpt_platform/backend/backend/util/feature_flag.py b/autogpt_platform/backend/backend/util/feature_flag.py index 27121304ca..c341666cdb 100644 --- a/autogpt_platform/backend/backend/util/feature_flag.py +++ b/autogpt_platform/backend/backend/util/feature_flag.py @@ -1,6 +1,7 @@ import contextlib import logging import os +import uuid from enum import Enum from functools import wraps from typing import Any, Awaitable, Callable, TypeVar @@ -101,6 +102,12 @@ async def _fetch_user_context_data(user_id: str) -> Context: """ builder = Context.builder(user_id).kind("user").anonymous(True) + try: + uuid.UUID(user_id) + except ValueError: + # Non-UUID key (e.g. "system") — skip Supabase lookup, return anonymous context. + return builder.build() + try: from backend.util.clients import get_supabase diff --git a/autogpt_platform/backend/scripts/download_transcripts.py b/autogpt_platform/backend/scripts/download_transcripts.py index 26204c3243..a9b32e8494 100644 --- a/autogpt_platform/backend/scripts/download_transcripts.py +++ b/autogpt_platform/backend/scripts/download_transcripts.py @@ -88,17 +88,19 @@ async def cmd_download(session_ids: list[str]) -> None: print(f"[{sid[:12]}] Not found in GCS") continue + content_str = ( + dl.content.decode("utf-8") if isinstance(dl.content, bytes) else dl.content + ) out = _transcript_path(sid) with open(out, "w") as f: - f.write(dl.content) + f.write(content_str) - lines = len(dl.content.strip().split("\n")) + lines = len(content_str.strip().split("\n")) meta = { "session_id": sid, "user_id": user_id, "message_count": dl.message_count, - "uploaded_at": dl.uploaded_at, - "transcript_bytes": len(dl.content), + "transcript_bytes": len(content_str), "transcript_lines": lines, } with open(_meta_path(sid), "w") as f: @@ -106,7 +108,7 @@ async def cmd_download(session_ids: list[str]) -> None: print( f"[{sid[:12]}] Saved: {lines} entries, " - f"{len(dl.content)} bytes, msg_count={dl.message_count}" + f"{len(content_str)} bytes, msg_count={dl.message_count}" ) print("\nDone. Run 'load' command to import into local dev environment.") @@ -227,7 +229,7 @@ async def cmd_load(session_ids: list[str]) -> None: await upload_transcript( user_id=user_id, session_id=sid, - content=content, + content=content.encode("utf-8"), message_count=msg_count, ) print(f"[{sid[:12]}] Stored transcript in local workspace storage") diff --git a/autogpt_platform/backend/test/copilot/test_transcript_watermark.py b/autogpt_platform/backend/test/copilot/test_transcript_watermark.py new file mode 100644 index 0000000000..bd88726339 --- /dev/null +++ b/autogpt_platform/backend/test/copilot/test_transcript_watermark.py @@ -0,0 +1,140 @@ +"""Unit tests for the transcript watermark (message_count) fix. + +The bug: upload used message_count=len(session.messages) (DB count). When a +prior turn's GCS upload failed silently, the JSONL on GCS was stale (e.g. 
+covered only T1-T12) but the meta.json watermark matched the full DB count +(e.g. 46). The next turn's gap-fill check (transcript_msg_count < msg_count-1) +never triggered, so the model silently lost context for the skipped turns. + +The fix: watermark = previous_coverage + 2 (current user+asst pair) when +use_resume=True and transcript_msg_count > 0. This ensures the watermark +reflects the JSONL content, not the DB count. + +These tests exercise _build_query_message directly to verify that gap-fill +triggers with the corrected watermark but NOT with the inflated (buggy) one. +""" + +from unittest.mock import MagicMock + +import pytest + +from backend.copilot.sdk.service import _build_query_message + + +def _make_messages(n_pairs: int, *, current_user: str = "current") -> list[MagicMock]: + """Build a flat list of n_pairs*2 alternating user/asst messages, plus + one trailing user message for the *current* turn.""" + msgs: list[MagicMock] = [] + for i in range(n_pairs): + u = MagicMock() + u.role = "user" + u.content = f"user message {i}" + a = MagicMock() + a.role = "assistant" + a.content = f"assistant response {i}" + msgs.extend([u, a]) + # Current turn's user message + cur = MagicMock() + cur.role = "user" + cur.content = current_user + msgs.append(cur) + return msgs + + +def _make_session(messages: list[MagicMock]) -> MagicMock: + session = MagicMock() + session.messages = messages + return session + + +@pytest.mark.asyncio +async def test_gap_fill_triggers_for_stale_jsonl(): + """Scenario: T1-T12 in JSONL (watermark=24), DB has T1-T22+Test (46 msgs). + + With the FIX: 'Test' uploaded watermark=26 (T12's 24 + 2 for 'Test'). + Next turn (T24) downloads watermark=26, DB has 47. + Gap check: 26 < 47-1=46 → TRUE → gap fills T14-T23. + """ + # T23 turns in DB (46 messages) + T24 user = 47 + msgs = _make_messages(23, current_user="memory test - recall all") + assert len(msgs) == 47 + + session = _make_session(msgs) + + # Watermark as uploaded by the FIX: T12 covered 24, 'Test' +2 = 26 + result_msg, _ = await _build_query_message( + current_message="memory test - recall all", + session=session, + use_resume=True, + transcript_msg_count=26, + session_id="test-session-id", + ) + + assert "" in result_msg, ( + "Expected gap-fill to inject when " + "watermark=26 < msg_count-1=46" + ) + + +@pytest.mark.asyncio +async def test_no_gap_fill_when_watermark_is_current(): + """When the JSONL is fully current (watermark = DB-1), no gap injected.""" + # T23 turns in DB (46 messages) + T24 user = 47 + msgs = _make_messages(23, current_user="next message") + session = _make_session(msgs) + + result_msg, _ = await _build_query_message( + current_message="next message", + session=session, + use_resume=True, + transcript_msg_count=46, # current — no gap + session_id="test-session-id", + ) + + assert ( + "" not in result_msg + ), "No gap-fill expected when watermark is current" + assert result_msg == "next message" + + +@pytest.mark.asyncio +async def test_inflated_watermark_suppresses_gap_fill(): + """Documents the original bug: inflated watermark suppresses gap-fill. + + 'Test' uploaded watermark=len(session.messages)=46 even though only 26 + messages are in the JSONL. Next turn: 46 < 47-1=46 → FALSE → no gap fill. 
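+
+    Spelled out against the gap-fill check from the module docstring::
+
+        transcript_msg_count < msg_count - 1
+        buggy:  46 < 47 - 1  → False → gap-fill suppressed
+        fixed:  26 < 47 - 1  → True  → missing turns injected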
+ """ + msgs = _make_messages(23, current_user="memory test") + session = _make_session(msgs) + + # Buggy watermark: inflated to DB count + result_msg, _ = await _build_query_message( + current_message="memory test", + session=session, + use_resume=True, + transcript_msg_count=46, # inflated — suppresses gap fill + session_id="test-session-id", + ) + + assert ( + "" not in result_msg + ), "With inflated watermark, gap-fill is suppressed — this documents the bug" + + +@pytest.mark.asyncio +async def test_fixed_watermark_fills_same_gap(): + """Same scenario but with the FIXED watermark triggers gap-fill.""" + msgs = _make_messages(23, current_user="memory test") + session = _make_session(msgs) + + result_msg, _ = await _build_query_message( + current_message="memory test", + session=session, + use_resume=True, + transcript_msg_count=26, # fixed watermark + session_id="test-session-id", + ) + + assert ( + "" in result_msg + ), "With fixed watermark=26, gap-fill triggers and injects missing turns" diff --git a/autogpt_platform/frontend/package.json b/autogpt_platform/frontend/package.json index 4661ab2050..292e64e8dd 100644 --- a/autogpt_platform/frontend/package.json +++ b/autogpt_platform/frontend/package.json @@ -155,6 +155,7 @@ "@types/twemoji": "13.1.2", "@vitejs/plugin-react": "5.1.2", "@vitest/coverage-v8": "4.0.17", + "agentation": "3.0.2", "axe-playwright": "2.2.2", "chromatic": "13.3.3", "concurrently": "9.2.1", diff --git a/autogpt_platform/frontend/pnpm-lock.yaml b/autogpt_platform/frontend/pnpm-lock.yaml index 057719def1..ad6429ac52 100644 --- a/autogpt_platform/frontend/pnpm-lock.yaml +++ b/autogpt_platform/frontend/pnpm-lock.yaml @@ -376,6 +376,9 @@ importers: '@vitest/coverage-v8': specifier: 4.0.17 version: 4.0.17(vitest@4.0.17(@opentelemetry/api@1.9.0)(@types/node@24.10.0)(happy-dom@20.3.4)(jiti@2.6.1)(jsdom@27.4.0)(msw@2.11.6(@types/node@24.10.0)(typescript@5.9.3))(terser@5.44.1)(yaml@2.8.2)) + agentation: + specifier: 3.0.2 + version: 3.0.2(react-dom@18.3.1(react@18.3.1))(react@18.3.1) axe-playwright: specifier: 2.2.2 version: 2.2.2(playwright@1.56.1) @@ -4119,6 +4122,17 @@ packages: resolution: {integrity: sha512-MnA+YT8fwfJPgBx3m60MNqakm30XOkyIoH1y6huTQvC0PwZG7ki8NacLBcrPbNoo8vEZy7Jpuk7+jMO+CUovTQ==} engines: {node: '>= 14'} + agentation@3.0.2: + resolution: {integrity: sha512-iGzBxFVTuZEIKzLY6AExSLAQH6i6SwxV4pAu7v7m3X6bInZ7qlZXAwrEqyc4+EfP4gM7z2RXBF6SF4DeH0f2lA==} + peerDependencies: + react: '>=18.0.0' + react-dom: '>=18.0.0' + peerDependenciesMeta: + react: + optional: true + react-dom: + optional: true + ai@6.0.134: resolution: {integrity: sha512-YalNEaavld/kE444gOcsMKXdVVRGEe0SK77fAFcWYcqLg+a7xKnEet8bdfrEAJTfnMjj01rhgrIL10903w1a5Q==} engines: {node: '>=18'} @@ -13119,6 +13133,11 @@ snapshots: agent-base@7.1.4: optional: true + agentation@3.0.2(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + optionalDependencies: + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + ai@6.0.134(zod@3.25.76): dependencies: '@ai-sdk/gateway': 3.0.77(zod@3.25.76) diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/Flow/Flow.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/Flow/Flow.tsx index 3a55fabf1d..186c8d96fe 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/Flow/Flow.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/Flow/Flow.tsx @@ -110,7 +110,7 @@ export const Flow = () => { event.preventDefault(); }} maxZoom={2} - minZoom={0.1} + minZoom={0.05} 
onDragOver={onDragOver} onDrop={onDrop} nodesDraggable={!isLocked} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/CopilotPage.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/CopilotPage.tsx index 88f70c75d8..62255037eb 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/CopilotPage.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/CopilotPage.tsx @@ -93,6 +93,7 @@ export function CopilotPage() { hasMoreMessages, isLoadingMore, loadMore, + forwardPaginated, // Mobile drawer isMobile, isDrawerOpen, @@ -217,6 +218,7 @@ export function CopilotPage() { hasMoreMessages={hasMoreMessages} isLoadingMore={isLoadingMore} onLoadMore={loadMore} + forwardPaginated={forwardPaginated} droppedFiles={droppedFiles} onDroppedFilesConsumed={handleDroppedFilesConsumed} historicalDurations={historicalDurations} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/useChatSession.test.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/useChatSession.test.ts new file mode 100644 index 0000000000..a6d8c5e896 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/useChatSession.test.ts @@ -0,0 +1,122 @@ +import { renderHook } from "@testing-library/react"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { useChatSession } from "../useChatSession"; + +const mockUseGetV2GetSession = vi.fn(); + +vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({ + useGetV2GetSession: (...args: unknown[]) => mockUseGetV2GetSession(...args), + usePostV2CreateSession: () => ({ mutateAsync: vi.fn(), isPending: false }), + getGetV2GetSessionQueryKey: (id: string) => ["session", id], + getGetV2ListSessionsQueryKey: () => ["sessions"], +})); + +vi.mock("@tanstack/react-query", () => ({ + useQueryClient: () => ({ + invalidateQueries: vi.fn(), + setQueryData: vi.fn(), + }), +})); + +vi.mock("nuqs", () => ({ + parseAsString: { withDefault: (v: unknown) => v }, + useQueryState: () => ["sess-1", vi.fn()], +})); + +vi.mock("../helpers/convertChatSessionToUiMessages", () => ({ + convertChatSessionMessagesToUiMessages: vi.fn(() => ({ + messages: [], + historicalDurations: new Map(), + })), +})); + +vi.mock("../helpers", () => ({ + resolveSessionDryRun: vi.fn(() => false), +})); + +vi.mock("@sentry/nextjs", () => ({ + captureException: vi.fn(), +})); + +function makeQueryResult(data: object | null) { + return { + data: data ? 
{ status: 200, data } : undefined, + isLoading: false, + isError: false, + isFetching: false, + refetch: vi.fn(), + }; +} + +describe("useChatSession — newestSequence and forwardPaginated", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("returns null / false when no session data", () => { + mockUseGetV2GetSession.mockReturnValue(makeQueryResult(null)); + const { result } = renderHook(() => useChatSession()); + expect(result.current.newestSequence).toBeNull(); + expect(result.current.forwardPaginated).toBe(false); + }); + + it("returns newestSequence from session data", () => { + mockUseGetV2GetSession.mockReturnValue( + makeQueryResult({ + messages: [], + has_more_messages: true, + oldest_sequence: 0, + newest_sequence: 99, + forward_paginated: false, + active_stream: null, + }), + ); + const { result } = renderHook(() => useChatSession()); + expect(result.current.newestSequence).toBe(99); + }); + + it("returns null for newestSequence when field is missing", () => { + mockUseGetV2GetSession.mockReturnValue( + makeQueryResult({ + messages: [], + has_more_messages: false, + oldest_sequence: 0, + newest_sequence: null, + forward_paginated: false, + active_stream: null, + }), + ); + const { result } = renderHook(() => useChatSession()); + expect(result.current.newestSequence).toBeNull(); + }); + + it("returns forwardPaginated=true when session is forward-paginated", () => { + mockUseGetV2GetSession.mockReturnValue( + makeQueryResult({ + messages: [], + has_more_messages: true, + oldest_sequence: 0, + newest_sequence: 49, + forward_paginated: true, + active_stream: null, + }), + ); + const { result } = renderHook(() => useChatSession()); + expect(result.current.forwardPaginated).toBe(true); + }); + + it("returns forwardPaginated=false when session is backward-paginated", () => { + mockUseGetV2GetSession.mockReturnValue( + makeQueryResult({ + messages: [], + has_more_messages: true, + oldest_sequence: 50, + newest_sequence: 99, + forward_paginated: false, + active_stream: null, + }), + ); + const { result } = renderHook(() => useChatSession()); + expect(result.current.forwardPaginated).toBe(false); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/useCopilotPage.test.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/useCopilotPage.test.ts new file mode 100644 index 0000000000..cd23a51195 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/useCopilotPage.test.ts @@ -0,0 +1,202 @@ +import { act, renderHook, waitFor } from "@testing-library/react"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { useCopilotPage } from "../useCopilotPage"; + +const mockUseChatSession = vi.fn(); +const mockUseCopilotStream = vi.fn(); +const mockUseLoadMoreMessages = vi.fn(); + +vi.mock("../useChatSession", () => ({ + useChatSession: (...args: unknown[]) => mockUseChatSession(...args), +})); +vi.mock("../useCopilotStream", () => ({ + useCopilotStream: (...args: unknown[]) => mockUseCopilotStream(...args), +})); +vi.mock("../useLoadMoreMessages", () => ({ + useLoadMoreMessages: (...args: unknown[]) => mockUseLoadMoreMessages(...args), +})); +vi.mock("../useCopilotNotifications", () => ({ + useCopilotNotifications: () => undefined, +})); +vi.mock("../useWorkflowImportAutoSubmit", () => ({ + useWorkflowImportAutoSubmit: () => undefined, +})); +vi.mock("../store", () => ({ + useCopilotUIStore: () => ({ + sessionToDelete: null, + setSessionToDelete: vi.fn(), + isDrawerOpen: false, + setDrawerOpen: 
vi.fn(), + copilotChatMode: "chat", + copilotLlmModel: null, + isDryRun: false, + }), +})); +vi.mock("../helpers/convertChatSessionToUiMessages", () => ({ + concatWithAssistantMerge: (a: unknown[], b: unknown[]) => [...a, ...b], +})); +vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({ + useDeleteV2DeleteSession: () => ({ mutate: vi.fn(), isPending: false }), + useGetV2ListSessions: () => ({ data: undefined, isLoading: false }), + getGetV2ListSessionsQueryKey: () => ["sessions"], +})); +vi.mock("@/components/molecules/Toast/use-toast", () => ({ + toast: vi.fn(), +})); +vi.mock("@/lib/direct-upload", () => ({ + uploadFileDirect: vi.fn(), +})); +vi.mock("@/lib/hooks/useBreakpoint", () => ({ + useBreakpoint: () => "lg", +})); +vi.mock("@/lib/supabase/hooks/useSupabase", () => ({ + useSupabase: () => ({ isUserLoading: false, isLoggedIn: true }), +})); +vi.mock("@tanstack/react-query", () => ({ + useQueryClient: () => ({ invalidateQueries: vi.fn() }), +})); +vi.mock("@/services/feature-flags/use-get-flag", () => ({ + Flag: { CHAT_MODE_OPTION: "CHAT_MODE_OPTION" }, + useGetFlag: () => false, +})); + +function makeBaseChatSession(overrides: Record = {}) { + return { + sessionId: "sess-1", + setSessionId: vi.fn(), + hydratedMessages: [], + rawSessionMessages: [], + historicalDurations: new Map(), + hasActiveStream: false, + hasMoreMessages: false, + oldestSequence: null, + newestSequence: null, + forwardPaginated: false, + isLoadingSession: false, + isSessionError: false, + createSession: vi.fn(), + isCreatingSession: false, + refetchSession: vi.fn(), + sessionDryRun: false, + ...overrides, + }; +} + +function makeBaseCopilotStream(overrides: Record = {}) { + return { + messages: [], + sendMessage: vi.fn(), + stop: vi.fn(), + status: "ready", + error: undefined, + isReconnecting: false, + isSyncing: false, + isUserStoppingRef: { current: false }, + rateLimitMessage: null, + dismissRateLimit: vi.fn(), + ...overrides, + }; +} + +function makeBaseLoadMore(overrides: Record = {}) { + return { + pagedMessages: [], + hasMore: false, + isLoadingMore: false, + loadMore: vi.fn(), + resetPaged: vi.fn(), + ...overrides, + }; +} + +describe("useCopilotPage — forwardPaginated message ordering", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("prepends pagedMessages before currentMessages when forwardPaginated=false", () => { + const pagedMsg = { id: "paged", role: "user" }; + const currentMsg = { id: "current", role: "assistant" }; + mockUseChatSession.mockReturnValue( + makeBaseChatSession({ forwardPaginated: false }), + ); + mockUseCopilotStream.mockReturnValue( + makeBaseCopilotStream({ messages: [currentMsg] }), + ); + mockUseLoadMoreMessages.mockReturnValue( + makeBaseLoadMore({ pagedMessages: [pagedMsg] }), + ); + + const { result } = renderHook(() => useCopilotPage()); + + // Backward: pagedMessages (older) come first + expect(result.current.messages[0]).toEqual(pagedMsg); + expect(result.current.messages[1]).toEqual(currentMsg); + }); + + it("appends pagedMessages after currentMessages when forwardPaginated=true", () => { + const pagedMsg = { id: "paged", role: "assistant" }; + const currentMsg = { id: "current", role: "user" }; + mockUseChatSession.mockReturnValue( + makeBaseChatSession({ forwardPaginated: true }), + ); + mockUseCopilotStream.mockReturnValue( + makeBaseCopilotStream({ messages: [currentMsg] }), + ); + mockUseLoadMoreMessages.mockReturnValue( + makeBaseLoadMore({ pagedMessages: [pagedMsg] }), + ); + + const { result } = renderHook(() => useCopilotPage()); + 
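+    // Expected merge order, as exercised by this test and the previous one
+    // (a sketch — the real ordering logic lives in useCopilotPage):
+    //   forwardPaginated ? [...current, ...paged] : [...paged, ...current]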
+ // Forward: currentMessages (beginning of session) come first + expect(result.current.messages[0]).toEqual(currentMsg); + expect(result.current.messages[1]).toEqual(pagedMsg); + }); + + it("calls resetPaged when forwardPaginated transitions false→true with paged messages", async () => { + const mockResetPaged = vi.fn(); + const pagedMsg = { id: "paged", role: "user" }; + + mockUseChatSession.mockReturnValue( + makeBaseChatSession({ forwardPaginated: false }), + ); + mockUseCopilotStream.mockReturnValue(makeBaseCopilotStream()); + mockUseLoadMoreMessages.mockReturnValue( + makeBaseLoadMore({ + pagedMessages: [pagedMsg], + resetPaged: mockResetPaged, + }), + ); + + const { rerender } = renderHook(() => useCopilotPage()); + + // Simulate session completing — forwardPaginated flips to true + mockUseChatSession.mockReturnValue( + makeBaseChatSession({ forwardPaginated: true }), + ); + + act(() => { + rerender(); + }); + + await waitFor(() => { + expect(mockResetPaged).toHaveBeenCalled(); + }); + }); + + it("does not call resetPaged when forwardPaginated is already true on mount", () => { + const mockResetPaged = vi.fn(); + mockUseChatSession.mockReturnValue( + makeBaseChatSession({ forwardPaginated: true }), + ); + mockUseCopilotStream.mockReturnValue(makeBaseCopilotStream()); + mockUseLoadMoreMessages.mockReturnValue( + makeBaseLoadMore({ pagedMessages: [], resetPaged: mockResetPaged }), + ); + + renderHook(() => useCopilotPage()); + + expect(mockResetPaged).not.toHaveBeenCalled(); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/useLoadMoreMessages.test.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/useLoadMoreMessages.test.ts new file mode 100644 index 0000000000..8f781e5f46 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/useLoadMoreMessages.test.ts @@ -0,0 +1,568 @@ +import { act, renderHook, waitFor } from "@testing-library/react"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { useLoadMoreMessages } from "../useLoadMoreMessages"; + +const mockGetV2GetSession = vi.fn(); + +vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({ + getV2GetSession: (...args: unknown[]) => mockGetV2GetSession(...args), +})); + +vi.mock("../helpers/convertChatSessionToUiMessages", () => ({ + convertChatSessionMessagesToUiMessages: vi.fn(() => ({ messages: [] })), + extractToolOutputsFromRaw: vi.fn(() => []), +})); + +const BASE_ARGS = { + sessionId: "sess-1", + initialOldestSequence: 0, + initialNewestSequence: 49, + initialHasMore: true, + forwardPaginated: true, + initialPageRawMessages: [], +}; + +function makeSuccessResponse(overrides: { + messages?: unknown[]; + has_more_messages?: boolean; + oldest_sequence?: number; + newest_sequence?: number; +}) { + return { + status: 200, + data: { + messages: overrides.messages ?? [], + has_more_messages: overrides.has_more_messages ?? false, + oldest_sequence: overrides.oldest_sequence ?? 0, + newest_sequence: overrides.newest_sequence ?? 
49, + }, + }; +} + +describe("useLoadMoreMessages", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("initialises with empty pagedMessages and correct cursors", () => { + const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS)); + expect(result.current.pagedMessages).toHaveLength(0); + expect(result.current.hasMore).toBe(true); + expect(result.current.isLoadingMore).toBe(false); + }); + + it("resetPaged clears paged state and sets hasMore=false during transition", () => { + const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS)); + + act(() => { + result.current.resetPaged(); + }); + + expect(result.current.pagedMessages).toHaveLength(0); + // hasMore must be false during transition to prevent forward loadMore + // from firing on the now-active session before forwardPaginated updates. + expect(result.current.hasMore).toBe(false); + expect(result.current.isLoadingMore).toBe(false); + }); + + it("resetPaged exposes a fresh loadMore via incremented epoch", () => { + const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS)); + // Just verify resetPaged is callable and doesn't throw. + expect(() => { + act(() => { + result.current.resetPaged(); + }); + }).not.toThrow(); + }); + + it("resets all state on sessionId change", () => { + const { result, rerender } = renderHook( + (props) => useLoadMoreMessages(props), + { initialProps: BASE_ARGS }, + ); + + rerender({ + ...BASE_ARGS, + sessionId: "sess-2", + initialOldestSequence: 10, + initialNewestSequence: 59, + initialHasMore: false, + }); + + expect(result.current.pagedMessages).toHaveLength(0); + expect(result.current.hasMore).toBe(false); + expect(result.current.isLoadingMore).toBe(false); + }); + + describe("loadMore — forward pagination", () => { + it("calls getV2GetSession with after_sequence and updates newestSequence", async () => { + const rawMsg = { role: "user", content: "hi", sequence: 50 }; + mockGetV2GetSession.mockResolvedValueOnce( + makeSuccessResponse({ + messages: [rawMsg], + has_more_messages: true, + newest_sequence: 99, + }), + ); + + const { result } = renderHook(() => + useLoadMoreMessages({ ...BASE_ARGS, forwardPaginated: true }), + ); + + await act(async () => { + await result.current.loadMore(); + }); + + expect(mockGetV2GetSession).toHaveBeenCalledWith( + "sess-1", + expect.objectContaining({ after_sequence: 49 }), + ); + expect(result.current.hasMore).toBe(true); + expect(result.current.isLoadingMore).toBe(false); + }); + + it("sets hasMore=false when response has no more messages", async () => { + mockGetV2GetSession.mockResolvedValueOnce( + makeSuccessResponse({ has_more_messages: false }), + ); + + const { result } = renderHook(() => + useLoadMoreMessages({ ...BASE_ARGS, forwardPaginated: true }), + ); + + await act(async () => { + await result.current.loadMore(); + }); + + expect(result.current.hasMore).toBe(false); + }); + + it("is a no-op when hasMore is false", async () => { + const { result } = renderHook(() => + useLoadMoreMessages({ + ...BASE_ARGS, + initialHasMore: false, + forwardPaginated: true, + }), + ); + + await act(async () => { + await result.current.loadMore(); + }); + + expect(mockGetV2GetSession).not.toHaveBeenCalled(); + }); + }); + + describe("loadMore — backward pagination", () => { + it("calls getV2GetSession with before_sequence", async () => { + mockGetV2GetSession.mockResolvedValueOnce( + makeSuccessResponse({ + messages: [{ role: "user", content: "old", sequence: 0 }], + has_more_messages: false, + oldest_sequence: 0, + }), + ); + + const { 
result } = renderHook(() => + useLoadMoreMessages({ + ...BASE_ARGS, + forwardPaginated: false, + initialOldestSequence: 50, + }), + ); + + await act(async () => { + await result.current.loadMore(); + }); + + expect(mockGetV2GetSession).toHaveBeenCalledWith( + "sess-1", + expect.objectContaining({ before_sequence: 50 }), + ); + expect(result.current.hasMore).toBe(false); + }); + }); + + describe("loadMore — error handling", () => { + it("does not set hasMore=false on first error", async () => { + mockGetV2GetSession.mockRejectedValueOnce(new Error("network error")); + + const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS)); + + await act(async () => { + await result.current.loadMore(); + }); + + // First error — hasMore still true + expect(result.current.hasMore).toBe(true); + expect(result.current.isLoadingMore).toBe(false); + }); + + it("sets hasMore=false after MAX_CONSECUTIVE_ERRORS (3) errors", async () => { + mockGetV2GetSession.mockRejectedValue(new Error("network error")); + + const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS)); + + for (let i = 0; i < 3; i++) { + await act(async () => { + await result.current.loadMore(); + }); + // Reset the in-flight guard between calls + await waitFor(() => expect(result.current.isLoadingMore).toBe(false)); + } + + expect(result.current.hasMore).toBe(false); + }); + + it("ignores non-200 response and increments error count", async () => { + mockGetV2GetSession.mockResolvedValueOnce({ status: 500, data: {} }); + + const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS)); + + await act(async () => { + await result.current.loadMore(); + }); + + // One error, not yet at threshold — hasMore still true + expect(result.current.hasMore).toBe(true); + expect(result.current.isLoadingMore).toBe(false); + }); + + it("sets hasMore=false after MAX_CONSECUTIVE_ERRORS (3) non-200 responses", async () => { + mockGetV2GetSession.mockResolvedValue({ status: 503, data: {} }); + + const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS)); + + for (let i = 0; i < 3; i++) { + await act(async () => { + await result.current.loadMore(); + }); + await waitFor(() => expect(result.current.isLoadingMore).toBe(false)); + } + + expect(result.current.hasMore).toBe(false); + }); + + it("discards in-flight error when epoch changes mid-flight (resetPaged called)", async () => { + let rejectRequest!: (e: Error) => void; + mockGetV2GetSession.mockReturnValueOnce( + new Promise((_, rej) => { + rejectRequest = rej; + }), + ); + + const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS)); + + act(() => { + result.current.loadMore(); + }); + + // Reset epoch mid-flight + act(() => { + result.current.resetPaged(); + }); + + // Reject the in-flight request — stale error should be discarded + await act(async () => { + rejectRequest(new Error("network error")); + }); + + // State unchanged: no hasMore=false, no errorCount, isLoadingMore cleared + expect(result.current.hasMore).toBe(false); // false from resetPaged + expect(result.current.isLoadingMore).toBe(false); + }); + }); + + describe("loadMore — forward pagination cursor advancement", () => { + it("advances newestSequence after a successful forward load", async () => { + mockGetV2GetSession.mockResolvedValueOnce( + makeSuccessResponse({ + messages: [{ role: "user", content: "hi", sequence: 50 }], + has_more_messages: true, + newest_sequence: 99, + }), + ); + + const { result } = renderHook(() => + useLoadMoreMessages({ ...BASE_ARGS, forwardPaginated: true }), + ); + + await 
act(async () => { + await result.current.loadMore(); + }); + + // A second loadMore should use after_sequence: 99 (advanced cursor) + mockGetV2GetSession.mockResolvedValueOnce( + makeSuccessResponse({ has_more_messages: false, newest_sequence: 149 }), + ); + + await act(async () => { + await result.current.loadMore(); + }); + + expect(mockGetV2GetSession).toHaveBeenLastCalledWith( + "sess-1", + expect.objectContaining({ after_sequence: 99 }), + ); + }); + + it("does not regress newestSequence when parent refetches after pages loaded", async () => { + mockGetV2GetSession.mockResolvedValueOnce( + makeSuccessResponse({ + messages: [{ role: "user", content: "msg", sequence: 50 }], + has_more_messages: true, + newest_sequence: 99, + }), + ); + + const { result, rerender } = renderHook( + (props) => useLoadMoreMessages(props), + { initialProps: { ...BASE_ARGS, forwardPaginated: true } }, + ); + + // Load one page — newestSequence advances to 99 + await act(async () => { + await result.current.loadMore(); + }); + + // Parent refetches with a lower newest_sequence (49) — should NOT regress cursor + rerender({ + ...BASE_ARGS, + forwardPaginated: true, + initialNewestSequence: 49, + }); + + // Next loadMore should still use the advanced cursor (99) + mockGetV2GetSession.mockResolvedValueOnce( + makeSuccessResponse({ has_more_messages: false, newest_sequence: 149 }), + ); + + await act(async () => { + await result.current.loadMore(); + }); + + expect(mockGetV2GetSession).toHaveBeenLastCalledWith( + "sess-1", + expect.objectContaining({ after_sequence: 99 }), + ); + }); + }); + + describe("loadMore — MAX_OLDER_MESSAGES truncation", () => { + it("truncates accumulated messages at MAX_OLDER_MESSAGES (2000)", async () => { + // Single load of 2001 messages exceeds the limit in one shot. + // This avoids relying on cross-render closure staleness: estimatedTotal = + // pagedRawMessages.length (0, fresh) + 2001 = 2001 >= 2000 → hasMore=false. + const args = { ...BASE_ARGS, forwardPaginated: false }; + + mockGetV2GetSession.mockResolvedValueOnce( + makeSuccessResponse({ + messages: Array.from({ length: 2001 }, (_, i) => ({ + role: "user", + content: `msg ${i}`, + sequence: i, + })), + has_more_messages: true, + oldest_sequence: 0, + }), + ); + + const { result } = renderHook(() => useLoadMoreMessages(args)); + + await act(async () => { + await result.current.loadMore(); + }); + + expect(result.current.hasMore).toBe(false); + }); + + it("forward truncation keeps first MAX_OLDER_MESSAGES items (not last)", async () => { + // 1990 messages already paged; load 20 more forward — total 2010 > 2000. + // Forward truncation must keep slice(0, 2000), not slice(-2000), + // to preserve the beginning of the conversation. + const forwardNearLimitArgs = { + ...BASE_ARGS, + forwardPaginated: true, + initialNewestSequence: 49, + initialOldestSequence: 0, + initialHasMore: true, + }; + + const { result } = renderHook((props) => useLoadMoreMessages(props), { + initialProps: forwardNearLimitArgs, + }); + + // First load: 1990 messages — advances newestSequence to 2039 + mockGetV2GetSession.mockResolvedValueOnce( + makeSuccessResponse({ + messages: Array.from({ length: 1990 }, (_, i) => ({ + role: "assistant", + content: `msg ${i + 50}`, + sequence: i + 50, + })), + has_more_messages: true, + newest_sequence: 2039, + }), + ); + + await act(async () => { + await result.current.loadMore(); + }); + + // Second load: 20 more messages pushes total to 2010 > 2000. 
+ // Truncation keeps seq 50..2049 (2000 items); discards seq 2050..2059 (10 items). + // Even though the server says has_more_messages=false, hasMore stays true + // because there are discarded items that need to be re-fetched. + // The cursor (newestSequence) advances to 2049 — the last kept item's sequence. + mockGetV2GetSession.mockResolvedValueOnce( + makeSuccessResponse({ + messages: Array.from({ length: 20 }, (_, i) => ({ + role: "assistant", + content: `msg ${i + 2040}`, + sequence: i + 2040, + })), + has_more_messages: false, + newest_sequence: 2059, + }), + ); + + await act(async () => { + await result.current.loadMore(); + }); + + // Truncation occurred (2010 > 2000): hasMore=true so discarded items can be fetched. + // Cursor advances to last kept item (seq 2049), not the server's newest (2059). + await waitFor(() => expect(result.current.hasMore).toBe(true)); + }); + }); + + describe("loadMore — null cursor guard", () => { + it("is a no-op when newestSequence is null (forwardPaginated=true)", async () => { + const { result } = renderHook(() => + useLoadMoreMessages({ + ...BASE_ARGS, + forwardPaginated: true, + initialNewestSequence: null, + }), + ); + + await act(async () => { + await result.current.loadMore(); + }); + + expect(mockGetV2GetSession).not.toHaveBeenCalled(); + }); + + it("is a no-op when oldestSequence is null (forwardPaginated=false)", async () => { + const { result } = renderHook(() => + useLoadMoreMessages({ + ...BASE_ARGS, + forwardPaginated: false, + initialOldestSequence: null, + }), + ); + + await act(async () => { + await result.current.loadMore(); + }); + + expect(mockGetV2GetSession).not.toHaveBeenCalled(); + }); + }); + + describe("pagedMessages — initialPageRawMessages extraToolOutputs", () => { + it("calls extractToolOutputsFromRaw for backward pagination with non-empty initialPageRawMessages", async () => { + const { extractToolOutputsFromRaw } = await import( + "../helpers/convertChatSessionToUiMessages" + ); + + const rawMsg = { role: "user", content: "old", sequence: 0 }; + mockGetV2GetSession.mockResolvedValueOnce( + makeSuccessResponse({ + messages: [rawMsg], + has_more_messages: false, + oldest_sequence: 0, + }), + ); + + const { result } = renderHook(() => + useLoadMoreMessages({ + ...BASE_ARGS, + forwardPaginated: false, + initialOldestSequence: 50, + initialPageRawMessages: [{ role: "assistant", content: "response" }], + }), + ); + + await act(async () => { + await result.current.loadMore(); + }); + + expect(extractToolOutputsFromRaw).toHaveBeenCalled(); + }); + + it("does NOT call extractToolOutputsFromRaw for forward pagination", async () => { + const { extractToolOutputsFromRaw } = await import( + "../helpers/convertChatSessionToUiMessages" + ); + + const rawMsg = { role: "assistant", content: "hi", sequence: 50 }; + mockGetV2GetSession.mockResolvedValueOnce( + makeSuccessResponse({ + messages: [rawMsg], + has_more_messages: false, + newest_sequence: 99, + }), + ); + + const { result } = renderHook(() => + useLoadMoreMessages({ + ...BASE_ARGS, + forwardPaginated: true, + initialPageRawMessages: [{ role: "user", content: "hello" }], + }), + ); + + await act(async () => { + await result.current.loadMore(); + }); + + expect(extractToolOutputsFromRaw).not.toHaveBeenCalled(); + }); + }); + + describe("loadMore — epoch / stale-response guard", () => { + it("discards response when epoch changes during flight (resetPaged called)", async () => { + let resolveRequest!: (v: unknown) => void; + mockGetV2GetSession.mockReturnValueOnce( + new 
Promise((res) => { + resolveRequest = res; + }), + ); + + const { result } = renderHook(() => useLoadMoreMessages(BASE_ARGS)); + + // Start the request without awaiting + act(() => { + result.current.loadMore(); + }); + + // Reset epoch mid-flight + act(() => { + result.current.resetPaged(); + }); + + // Now resolve the in-flight request + await act(async () => { + resolveRequest( + makeSuccessResponse({ messages: [{ role: "user", content: "hi" }] }), + ); + }); + + // Response discarded — pagedMessages stays empty, isLoadingMore stays false + expect(result.current.pagedMessages).toHaveLength(0); + expect(result.current.isLoadingMore).toBe(false); + }); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatContainer/ChatContainer.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatContainer/ChatContainer.tsx index 7f3c1d0328..6731057658 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatContainer/ChatContainer.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatContainer/ChatContainer.tsx @@ -30,6 +30,7 @@ export interface ChatContainerProps { hasMoreMessages?: boolean; isLoadingMore?: boolean; onLoadMore?: () => void; + forwardPaginated?: boolean; /** Files dropped onto the chat window. */ droppedFiles?: File[]; /** Called after droppedFiles have been consumed by ChatInput. */ @@ -54,6 +55,7 @@ export const ChatContainer = ({ hasMoreMessages, isLoadingMore, onLoadMore, + forwardPaginated, droppedFiles, onDroppedFilesConsumed, historicalDurations, @@ -108,6 +110,7 @@ export const ChatContainer = ({ hasMoreMessages={hasMoreMessages} isLoadingMore={isLoadingMore} onLoadMore={onLoadMore} + forwardPaginated={forwardPaginated} onRetry={handleRetry} historicalDurations={historicalDurations} /> diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/ChatInput.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/ChatInput.tsx index 44f59fcc39..ae24800142 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/ChatInput.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/ChatInput.tsx @@ -86,11 +86,11 @@ export function ChatInput({ title: next === "advanced" ? "Switched to Advanced model" - : "Switched to Standard model", + : "Switched to Balanced model", description: next === "advanced" ? "Using the highest-capability model." 
- : "Using the balanced standard model.", + : "Using the balanced default model.", }); } diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/__tests__/ChatInput.test.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/__tests__/ChatInput.test.tsx index b5a94a3bea..5bac773deb 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/__tests__/ChatInput.test.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/__tests__/ChatInput.test.tsx @@ -162,10 +162,15 @@ describe("ChatInput mode toggle", () => { expect(mockSetCopilotChatMode).toHaveBeenCalledWith("extended_thinking"); }); - it("hides toggle button when streaming", () => { + it("hides toggle buttons when streaming", () => { mockFlagValue = true; render(); - expect(screen.queryByLabelText(/switch to/i)).toBeNull(); + expect( + screen.queryByLabelText(/switch to (fast|extended thinking) mode/i), + ).toBeNull(); + expect( + screen.queryByLabelText(/switch to (advanced|balanced|standard) model/i), + ).toBeNull(); }); it("shows mode toggle when hasSession is true and not streaming", () => { @@ -234,7 +239,7 @@ describe("ChatInput model toggle", () => { mockFlagValue = true; mockCopilotLlmModel = "advanced"; render(); - fireEvent.click(screen.getByLabelText(/switch to standard model/i)); + fireEvent.click(screen.getByLabelText(/switch to balanced model/i)); expect(mockSetCopilotLlmModel).toHaveBeenCalledWith("standard"); }); @@ -288,10 +293,10 @@ describe("ChatInput model toggle", () => { mockFlagValue = true; mockCopilotLlmModel = "advanced"; render(); - fireEvent.click(screen.getByLabelText(/switch to standard model/i)); + fireEvent.click(screen.getByLabelText(/switch to balanced model/i)); expect(toast).toHaveBeenCalledWith( expect.objectContaining({ - title: expect.stringMatching(/switched to standard model/i), + title: expect.stringMatching(/switched to balanced model/i), }), ); }); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/DryRunToggleButton.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/DryRunToggleButton.tsx index 36c84d6826..a0b6b5b8f1 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/DryRunToggleButton.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/DryRunToggleButton.tsx @@ -2,6 +2,11 @@ import { cn } from "@/lib/utils"; import { Flask } from "@phosphor-icons/react"; +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from "@/components/ui/tooltip"; // This button is only rendered on NEW chats (no active session). // Once a session exists, it is hidden — the session's dry_run flag is @@ -14,27 +19,31 @@ interface Props { export function DryRunToggleButton({ isDryRun, onToggle }: Props) { return ( - + + + + + + {isDryRun + ? "Test mode on — new sessions run without performing real actions (click to turn off)." 
+ : "Turn on test mode to try prompts without performing real actions."} + + ); } diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/ModeToggleButton.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/ModeToggleButton.tsx index 6a3ab0d34d..5636123324 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/ModeToggleButton.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/ModeToggleButton.tsx @@ -2,6 +2,11 @@ import { cn } from "@/lib/utils"; import { Brain, Lightning } from "@phosphor-icons/react"; +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from "@/components/ui/tooltip"; import type { CopilotMode } from "../../../store"; interface Props { @@ -11,37 +16,42 @@ interface Props { export function ModeToggleButton({ mode, onToggle }: Props) { const isExtended = mode === "extended_thinking"; + + const tooltipText = isExtended + ? "Extended Thinking — deeper reasoning (click to switch to Fast)" + : "Fast mode — quicker responses (click to switch to Thinking)"; + return ( - + + + + + {tooltipText} + ); } diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/ModelToggleButton.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/ModelToggleButton.tsx index cb3bc25f4f..68ec4d5fac 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/ModelToggleButton.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/ModelToggleButton.tsx @@ -2,6 +2,11 @@ import { cn } from "@/lib/utils"; import { Cpu } from "@phosphor-icons/react"; +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from "@/components/ui/tooltip"; import type { CopilotLlmModel } from "../../../store"; interface Props { @@ -12,27 +17,33 @@ interface Props { export function ModelToggleButton({ model, onToggle }: Props) { const isAdvanced = model === "advanced"; return ( - + + + + + + {isAdvanced + ? "Using the highest-capability model (click to switch to Balanced)." + : "Using the balanced default model (click to switch to Advanced)."} + + ); } diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/__tests__/DryRunToggleButton.test.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/__tests__/DryRunToggleButton.test.tsx index d8920c8749..f48f8a40c8 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/__tests__/DryRunToggleButton.test.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/__tests__/DryRunToggleButton.test.tsx @@ -1,21 +1,32 @@ -import { render, screen, fireEvent, cleanup } from "@testing-library/react"; +import { + render as rtlRender, + screen, + fireEvent, + cleanup, +} from "@testing-library/react"; import { afterEach, describe, expect, it, vi } from "vitest"; +import type { ReactElement } from "react"; +import { TooltipProvider } from "@/components/ui/tooltip"; import { DryRunToggleButton } from "../DryRunToggleButton"; afterEach(cleanup); +function render(ui: ReactElement) { + return rtlRender({ui}); +} + // DryRunToggleButton only appears on new chats (no active session). 
// It has no readOnly/isStreaming props — those scenarios are handled by hiding // the button entirely at the ChatInput level when hasSession is true. describe("DryRunToggleButton", () => { - it("shows Test label when isDryRun is true", () => { + it("shows enabled label when isDryRun is true", () => { render(); - expect(screen.getByText("Test")).toBeTruthy(); + expect(screen.getByText("Test mode enabled")).toBeTruthy(); }); - it("shows no text label when isDryRun is false", () => { + it("shows enable label when isDryRun is false", () => { render(); - expect(screen.queryByText("Test")).toBeNull(); + expect(screen.getByText("Enable test mode")).toBeTruthy(); }); it("calls onToggle when clicked", () => { diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/__tests__/ModelToggleButton.test.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/__tests__/ModelToggleButton.test.tsx index a17e702f1a..6193eb8694 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/__tests__/ModelToggleButton.test.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/__tests__/ModelToggleButton.test.tsx @@ -1,9 +1,20 @@ -import { render, screen, fireEvent, cleanup } from "@testing-library/react"; +import { + render as rtlRender, + screen, + fireEvent, + cleanup, +} from "@testing-library/react"; import { afterEach, describe, expect, it, vi } from "vitest"; +import type { ReactElement } from "react"; +import { TooltipProvider } from "@/components/ui/tooltip"; import { ModelToggleButton } from "../ModelToggleButton"; afterEach(cleanup); +function render(ui: ReactElement) { + return rtlRender({ui}); +} + describe("ModelToggleButton", () => { it("shows no text label when model is standard", () => { render(); @@ -31,7 +42,7 @@ describe("ModelToggleButton", () => { it("sets aria-pressed=true for advanced", () => { render(); - const btn = screen.getByLabelText("Switch to Standard model"); + const btn = screen.getByLabelText("Switch to Balanced model"); expect(btn.getAttribute("aria-pressed")).toBe("true"); }); }); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx index ef2cead564..6b1f22708d 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/ChatMessagesContainer.tsx @@ -43,6 +43,10 @@ interface Props { hasMoreMessages?: boolean; isLoadingMore?: boolean; onLoadMore?: () => void; + /** When true the load-more sentinel is placed at the bottom (forward + * pagination for completed sessions). When false it is at the top + * (backward pagination for active sessions). */ + forwardPaginated?: boolean; onRetry?: () => void; historicalDurations?: Map; } @@ -140,11 +144,25 @@ export function LoadMoreSentinel({ isLoading, messageCount, onLoadMore, + rootMargin = "200px 0px 0px 0px", + adjustScroll = true, + forwardPaginated = false, }: { hasMore: boolean; isLoading: boolean; messageCount: number; onLoadMore: () => void; + /** IntersectionObserver rootMargin. 
Top sentinel uses "200px 0px 0px 0px" + * (pre-trigger when approaching from above); bottom sentinel should use + * "0px 0px 200px 0px" (pre-trigger when approaching from below). */ + rootMargin?: string; + /** Whether to adjust scrollTop after load to preserve visual position. + * True for backward pagination (prepend above); false for forward + * pagination (append below) where no adjustment is needed. */ + adjustScroll?: boolean; + /** When true the button reads "Load newer messages" (forward pagination). + * When false (default) it reads "Load older messages". */ + forwardPaginated?: boolean; }) { const sentinelRef = useRef(null); const onLoadMoreRef = useRef(onLoadMore); @@ -189,11 +207,11 @@ export function LoadMoreSentinel({ if (autoFillRoundsRef.current >= MAX_AUTO_FILL_ROUNDS) return; captureAndLoad(true); }, - { rootMargin: "200px 0px 0px 0px" }, + { rootMargin }, ); observer.observe(sentinelRef.current); return () => observer.disconnect(); - }, [hasMore, isLoading, scrollRef]); + }, [hasMore, isLoading, rootMargin, scrollRef]); // After React commits new DOM nodes (prepended messages), adjust // scrollTop so the user stays at the same visual position. @@ -206,7 +224,9 @@ export function LoadMoreSentinel({ scrollSnapshotRef.current; if (!el || prevHeight === 0) return; const delta = el.scrollHeight - prevHeight; - if (delta > 0) { + // Only restore scroll position for backward pagination (content prepended + // above). Forward pagination appends below — no adjustment needed. + if (adjustScroll && delta > 0) { el.scrollTop = prevTop + delta; } // Reset the auto-fill backoff whenever the container becomes @@ -220,7 +240,7 @@ export function LoadMoreSentinel({ } scrollSnapshotRef.current = { scrollHeight: 0, scrollTop: 0 }; autoTriggeredRef.current = false; - }, [messageCount, scrollRef]); + }, [adjustScroll, messageCount, scrollRef]); return (
captureAndLoad(false)} > - Load older messages + {forwardPaginated ? "Load newer messages" : "Load older messages"} ) )} @@ -256,6 +276,7 @@ export function ChatMessagesContainer({ hasMoreMessages, isLoadingMore, onLoadMore, + forwardPaginated, onRetry, historicalDurations, }: Props) { @@ -334,7 +355,7 @@ export function ChatMessagesContainer({ } > - {hasMoreMessages && onLoadMore && ( + {hasMoreMessages && onLoadMore && !forwardPaginated && ( )} + {hasMoreMessages && onLoadMore && forwardPaginated && ( + + )} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/__tests__/ChatMessagesContainer.test.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/__tests__/ChatMessagesContainer.test.tsx new file mode 100644 index 0000000000..ca7ee0d181 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/__tests__/ChatMessagesContainer.test.tsx @@ -0,0 +1,173 @@ +import { render, screen, cleanup } from "@/tests/integrations/test-utils"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import { ChatMessagesContainer } from "../ChatMessagesContainer"; + +const mockScrollEl = { + scrollHeight: 100, + scrollTop: 0, + clientHeight: 500, +}; + +vi.mock("use-stick-to-bottom", () => ({ + useStickToBottomContext: () => ({ scrollRef: { current: mockScrollEl } }), + Conversation: ({ children }: { children: React.ReactNode }) => ( +
+    <div>{children}</div>
+  ),
+  ConversationContent: ({ children }: { children: React.ReactNode }) => (
+    <div>{children}</div>
+  ),
+  ConversationScrollButton: () => null,
+}));
+
+vi.mock("@/components/ai-elements/conversation", () => ({
+  Conversation: ({ children }: { children: React.ReactNode }) => (
+    <div>{children}</div>
+  ),
+  ConversationContent: ({ children }: { children: React.ReactNode }) => (
+    <div>{children}</div>
+  ),
+  ConversationScrollButton: () => null,
+}));
+
+vi.mock("@/components/ai-elements/message", () => ({
+  Message: ({ children }: { children: React.ReactNode }) => (
+    <div>{children}</div>
+  ),
+  MessageContent: ({ children }: { children: React.ReactNode }) => (
+    <div>{children}</div>
+  ),
+  MessageActions: ({ children }: { children: React.ReactNode }) => (
+    <div>{children}</div>
+ ), +})); + +vi.mock("../components/AssistantMessageActions", () => ({ + AssistantMessageActions: () => null, +})); +vi.mock("../components/CopyButton", () => ({ CopyButton: () => null })); +vi.mock("../components/CollapsedToolGroup", () => ({ + CollapsedToolGroup: () => null, +})); +vi.mock("../components/MessageAttachments", () => ({ + MessageAttachments: () => null, +})); +vi.mock("../components/MessagePartRenderer", () => ({ + MessagePartRenderer: () => null, +})); +vi.mock("../components/ReasoningCollapse", () => ({ + ReasoningCollapse: () => null, +})); +vi.mock("../components/ThinkingIndicator", () => ({ + ThinkingIndicator: () => null, +})); +vi.mock("../../JobStatsBar/TurnStatsBar", () => ({ + TurnStatsBar: () => null, +})); +vi.mock("../../JobStatsBar/useElapsedTimer", () => ({ + useElapsedTimer: () => ({ elapsedSeconds: 0 }), +})); +vi.mock("../../CopilotPendingReviews/CopilotPendingReviews", () => ({ + CopilotPendingReviews: () => null, +})); +vi.mock("../helpers", () => ({ + buildRenderSegments: () => [], + getTurnMessages: () => [], + parseSpecialMarkers: () => ({ markerType: null }), + splitReasoningAndResponse: (parts: unknown[]) => ({ + reasoningParts: [], + responseParts: parts, + }), +})); + +type ObserverCallback = (entries: { isIntersecting: boolean }[]) => void; +class MockIntersectionObserver { + static lastCallback: ObserverCallback | null = null; + private callback: ObserverCallback; + constructor(cb: ObserverCallback) { + this.callback = cb; + MockIntersectionObserver.lastCallback = cb; + } + observe() {} + disconnect() {} + unobserve() {} + takeRecords() { + return []; + } + root = null; + rootMargin = ""; + thresholds = []; +} + +const BASE_PROPS = { + messages: [], + status: "ready" as const, + error: undefined, + isLoading: false, + sessionID: "sess-1", + hasMoreMessages: true, + isLoadingMore: false, + onLoadMore: vi.fn(), + onRetry: vi.fn(), +}; + +describe("ChatMessagesContainer", () => { + beforeEach(() => { + mockScrollEl.scrollHeight = 100; + mockScrollEl.scrollTop = 0; + mockScrollEl.clientHeight = 500; + MockIntersectionObserver.lastCallback = null; + vi.stubGlobal("IntersectionObserver", MockIntersectionObserver); + }); + + afterEach(() => { + cleanup(); + vi.unstubAllGlobals(); + }); + + it("renders top sentinel when forwardPaginated is false (backward pagination)", () => { + render(); + expect( + screen.getByRole("button", { name: /load older messages/i }), + ).toBeDefined(); + }); + + it("renders top sentinel when forwardPaginated is undefined (default, backward)", () => { + render(); + expect( + screen.getByRole("button", { name: /load older messages/i }), + ).toBeDefined(); + }); + + it("renders bottom sentinel when forwardPaginated is true (forward pagination)", () => { + render(); + expect( + screen.getByRole("button", { name: /load newer messages/i }), + ).toBeDefined(); + }); + + it("hides sentinel when hasMoreMessages is false", () => { + render( + , + ); + expect( + screen.queryByRole("button", { name: /load older messages/i }), + ).toBeNull(); + }); + + it("hides sentinel when onLoadMore is not provided", () => { + render( + , + ); + expect( + screen.queryByRole("button", { name: /load older messages/i }), + ).toBeNull(); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/__tests__/LoadMoreSentinel.test.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/__tests__/LoadMoreSentinel.test.tsx index 3cbf4cbe48..d3f4f08c9e 100644 --- 
a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/__tests__/LoadMoreSentinel.test.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatMessagesContainer/__tests__/LoadMoreSentinel.test.tsx @@ -172,6 +172,36 @@ describe("LoadMoreSentinel", () => { expect(mockScrollEl.scrollTop).toBe(200); }); + it("does NOT adjust scroll when adjustScroll=false (forward pagination)", () => { + mockScrollEl.scrollHeight = 100; + mockScrollEl.scrollTop = 50; + const onLoadMore = vi.fn(); + const { rerender } = render( + , + ); + // Fire observer to capture snapshot. + MockIntersectionObserver.lastCallback?.([{ isIntersecting: true }]); + // Simulate DOM growing from appended newer messages (forward load-more). + mockScrollEl.scrollHeight = 300; + rerender( + , + ); + // scrollTop should remain unchanged — no jump for forward pagination. + expect(mockScrollEl.scrollTop).toBe(50); + }); + it("ignores same-frame duplicate triggers until isLoading transitions", () => { const onLoadMore = vi.fn(); const { rerender } = render( diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/EmptySession/EmptySession.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/EmptySession/EmptySession.tsx index 0bd0cb8a5b..933172b880 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/EmptySession/EmptySession.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/EmptySession/EmptySession.tsx @@ -13,6 +13,10 @@ import { getSuggestionThemes, } from "./helpers"; import { SuggestionThemes } from "./components/SuggestionThemes/SuggestionThemes"; +import { PulseChips } from "../PulseChips/PulseChips"; +import { usePulseChips } from "../PulseChips/usePulseChips"; +import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag"; +import { EditNameDialog } from "./components/EditNameDialog/EditNameDialog"; interface Props { inputLayoutId: string; @@ -34,6 +38,8 @@ export function EmptySession({ }: Props) { const { user } = useSupabase(); const greetingName = getGreetingName(user); + const isAgentBriefingEnabled = useGetFlag(Flag.AGENT_BRIEFING); + const pulseChips = usePulseChips(); const { data: suggestedPromptsResponse, isLoading: isLoadingPrompts } = useGetV2GetSuggestedPrompts({ @@ -75,11 +81,16 @@ export function EmptySession({
Hey, {greetingName} + Tell me about your work — I'll find what to automate. + {isAgentBriefingEnabled && ( + + )} +
+ + + + +
+ setName(e.target.value)} + onKeyDown={(e) => { + if (e.key === "Enter") { + e.preventDefault(); + handleSave(); + } + }} + /> + +
+
+ + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/EmptySession/components/EditNameDialog/__tests__/EditNameDialog.test.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/EmptySession/components/EditNameDialog/__tests__/EditNameDialog.test.tsx new file mode 100644 index 0000000000..89029f211e --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/EmptySession/components/EditNameDialog/__tests__/EditNameDialog.test.tsx @@ -0,0 +1,135 @@ +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { + fireEvent, + render, + screen, + waitFor, +} from "@/tests/integrations/test-utils"; +import { server } from "@/mocks/mock-server"; +import { http, HttpResponse } from "msw"; +import { EditNameDialog } from "../EditNameDialog"; + +const mockToast = vi.hoisted(() => vi.fn()); +const mockRefreshSession = vi.hoisted(() => vi.fn()); + +vi.mock("@/components/molecules/Toast/use-toast", () => ({ + useToast: () => ({ toast: mockToast }), +})); + +vi.mock("@/lib/supabase/hooks/useSupabase", () => ({ + useSupabase: () => ({ + refreshSession: mockRefreshSession, + }), +})); + +function mockUpdateNameSuccess() { + server.use( + http.put("/api/auth/user", () => { + return HttpResponse.json({ user: { id: "u1" } }); + }), + ); +} + +function mockUpdateNameError(message = "Network error") { + server.use( + http.put("/api/auth/user", () => { + return HttpResponse.json({ error: message }, { status: 400 }); + }), + ); +} + +async function openDialogAndGetInput() { + const trigger = screen.getByRole("button"); + fireEvent.click(trigger); + await screen.findAllByLabelText(/display name/i); + const inputs = + document.querySelectorAll("input#display-name"); + return inputs[0]; +} + +function getSaveButton() { + const saves = screen.getAllByRole("button", { name: /save/i }); + return saves[0] as HTMLButtonElement; +} + +describe("EditNameDialog", () => { + beforeEach(() => { + mockToast.mockReset(); + mockRefreshSession.mockReset(); + mockRefreshSession.mockResolvedValue({ user: { id: "u1" } }); + }); + + test("opens dialog with current name prefilled", async () => { + mockUpdateNameSuccess(); + render(); + + const input = await openDialogAndGetInput(); + expect(input.value).toBe("Alice"); + }); + + test("saves name via API route and closes dialog", async () => { + mockUpdateNameSuccess(); + render(); + + const input = await openDialogAndGetInput(); + fireEvent.change(input, { target: { value: "Bob" } }); + fireEvent.click(getSaveButton()); + + await waitFor(() => { + expect(mockRefreshSession).toHaveBeenCalled(); + }); + expect(mockToast).toHaveBeenCalledWith({ title: "Name updated" }); + }); + + test("shows error toast when API returns error", async () => { + mockUpdateNameError("Network error"); + render(); + + const input = await openDialogAndGetInput(); + fireEvent.change(input, { target: { value: "Bob" } }); + fireEvent.click(getSaveButton()); + + await waitFor(() => { + expect(mockToast).toHaveBeenCalledWith( + expect.objectContaining({ + title: "Failed to update name", + description: "Network error", + variant: "destructive", + }), + ); + }); + expect(mockRefreshSession).not.toHaveBeenCalled(); + }); + + test("shows warning toast when refreshSession returns an error", async () => { + mockUpdateNameSuccess(); + mockRefreshSession.mockResolvedValue({ error: "refresh failed" }); + + render(); + + const input = await openDialogAndGetInput(); + fireEvent.change(input, { target: { value: "Bob" } }); + 
fireEvent.click(getSaveButton()); + + await waitFor(() => { + expect(mockToast).toHaveBeenCalledWith( + expect.objectContaining({ + title: "Name saved, but session refresh failed", + description: "refresh failed", + variant: "destructive", + }), + ); + }); + expect(mockToast).not.toHaveBeenCalledWith({ title: "Name updated" }); + }); + + test("disables Save button while input is empty", async () => { + mockUpdateNameSuccess(); + render(); + + const input = await openDialogAndGetInput(); + fireEvent.change(input, { target: { value: " " } }); + + expect(getSaveButton().disabled).toBe(true); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/PulseChips/PulseChips.module.css b/autogpt_platform/frontend/src/app/(platform)/copilot/components/PulseChips/PulseChips.module.css new file mode 100644 index 0000000000..da221fb7d8 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/PulseChips/PulseChips.module.css @@ -0,0 +1,93 @@ +.glassPanel { + position: relative; + isolation: isolate; +} + +.glassPanel::before { + content: ""; + position: absolute; + inset: 0; + border-radius: inherit; + padding: 1px; + background: conic-gradient( + from var(--border-angle, 0deg), + rgba(129, 120, 228, 0.08), + rgba(129, 120, 228, 0.28), + rgba(168, 130, 255, 0.18), + rgba(129, 120, 228, 0.08), + rgba(99, 102, 241, 0.24), + rgba(129, 120, 228, 0.08) + ); + -webkit-mask: + linear-gradient(#000 0 0) content-box, + linear-gradient(#000 0 0); + mask: + linear-gradient(#000 0 0) content-box, + linear-gradient(#000 0 0); + -webkit-mask-composite: xor; + mask-composite: exclude; + animation: rotate-border 6s linear infinite; + pointer-events: none; + z-index: -1; +} + +@property --border-angle { + syntax: ""; + initial-value: 0deg; + inherits: false; +} + +@keyframes rotate-border { + to { + --border-angle: 360deg; + } +} + +.chip { + overflow: hidden; +} + +@media (hover: hover) { + .chip { + padding-bottom: 0.9rem; + } +} + +@media (hover: none) { + .chip { + padding-bottom: 2.25rem; + } +} + +.chipActions { + position: absolute; + inset-inline: 0; + bottom: 0; + background: rgba(255, 255, 255, 0.95); + backdrop-filter: blur(4px); + -webkit-backdrop-filter: blur(4px); +} + +@media (hover: hover) { + .chipActions { + opacity: 0; + transform: translateY(100%); + transition: + opacity 0.2s ease-out, + transform 0.2s ease-out; + } + + .chip:hover .chipActions { + opacity: 1; + transform: translateY(0); + } + + .chipContent { + transition: filter 0.2s ease-out; + } + + .chip:hover .chipContent { + filter: blur(2px); + opacity: 0.5; + } +} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/PulseChips/PulseChips.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/PulseChips/PulseChips.tsx new file mode 100644 index 0000000000..f369ad0c05 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/PulseChips/PulseChips.tsx @@ -0,0 +1,116 @@ +"use client"; + +import { Text } from "@/components/atoms/Text/Text"; +import { + ArrowRightIcon, + EyeIcon, + ChatCircleDotsIcon, +} from "@phosphor-icons/react"; +import NextLink from "next/link"; +import { StatusBadge } from "@/app/(platform)/library/components/StatusBadge/StatusBadge"; +import styles from "./PulseChips.module.css"; +import type { PulseChipData } from "./types"; + +interface Props { + chips: PulseChipData[]; + onChipClick?: (prompt: string) => void; +} + +export function PulseChips({ chips, onChipClick }: Props) { + if (chips.length === 
0) return null; + + return ( +
+    <div className={styles.glassPanel}>
+      <div>
+        <Text>What's happening with your agents</Text>
+        <NextLink href="/library">
+          View all
+          <ArrowRightIcon />
+        </NextLink>
+      </div>
+      <div>
+        {chips.map((chip) => (
+          <PulseChip key={chip.id} chip={chip} onAsk={onChipClick} />
+        ))}
+      </div>
+    </div>
+  );
+}
+
+interface ChipProps {
+  chip: PulseChipData;
+  onAsk?: (prompt: string) => void;
+}
+
+function PulseChip({ chip, onAsk }: ChipProps) {
+  function handleAsk() {
+    const prompt = buildChipPrompt(chip);
+    onAsk?.(prompt);
+  }
+
+  return (
+    <div className={styles.chip}>
+      <div className={styles.chipContent}>
+        {chip.priority === "success" ? (
+          <Text>Completed</Text>
+        ) : (
+          <StatusBadge status={chip.status} />
+        )}
+        <Text>{chip.name}</Text>
+        <Text>{chip.shortMessage}</Text>
+      </div>
+      <div className={styles.chipActions}>
+        <button type="button" onClick={handleAsk}>
+          <ChatCircleDotsIcon />
+          Ask
+        </button>
+        <NextLink href={`/library/agents/${chip.agentID}`}>
+          <EyeIcon />
+          See
+        </NextLink>
+      </div>
+    </div>
+ ); +} + +function buildChipPrompt(chip: PulseChipData): string { + if (chip.priority === "success") { + return `${chip.name} just finished a run — can you summarize what it did?`; + } + switch (chip.status) { + case "error": + return `What happened with ${chip.name}? It has an error — can you check?`; + case "running": + return `Give me a status update on ${chip.name} — what has it done so far?`; + case "idle": + return `${chip.name} hasn't run recently. Should I keep it or update and re-run it?`; + default: + return `Tell me about ${chip.name} — what's its current status?`; + } +} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/PulseChips/__tests__/PulseChips.test.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/PulseChips/__tests__/PulseChips.test.tsx new file mode 100644 index 0000000000..2496929b58 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/PulseChips/__tests__/PulseChips.test.tsx @@ -0,0 +1,105 @@ +import { describe, expect, test, vi } from "vitest"; +import { render, screen, fireEvent } from "@/tests/integrations/test-utils"; +import { PulseChips } from "../PulseChips"; +import type { PulseChipData } from "../types"; + +function makeChip(overrides: Partial = {}): PulseChipData { + return { + id: "chip-1", + agentID: "agent-1", + name: "Test Agent", + status: "running", + priority: "running", + shortMessage: "Doing work…", + ...overrides, + }; +} + +describe("PulseChips", () => { + test("renders nothing when chips array is empty", () => { + const { container } = render(); + expect(container.innerHTML).toBe(""); + }); + + test("renders chip names and messages", () => { + const chips = [ + makeChip({ id: "1", name: "Alpha Bot", shortMessage: "Running task A" }), + makeChip({ id: "2", name: "Beta Bot", shortMessage: "Running task B" }), + ]; + + render(); + + expect(screen.getByText("Alpha Bot")).toBeDefined(); + expect(screen.getByText("Running task A")).toBeDefined(); + expect(screen.getByText("Beta Bot")).toBeDefined(); + expect(screen.getByText("Running task B")).toBeDefined(); + }); + + test("renders section heading and View all link", () => { + render(); + + expect(screen.getByText("What's happening with your agents")).toBeDefined(); + expect(screen.getByText("View all")).toBeDefined(); + }); + + test("shows Completed badge for success priority chips", () => { + render( + , + ); + + expect(screen.getByText("Completed")).toBeDefined(); + }); + + test("calls onChipClick with generated prompt when Ask is clicked", () => { + const onChipClick = vi.fn(); + render( + , + ); + + fireEvent.click(screen.getByText("Ask")); + + expect(onChipClick).toHaveBeenCalledWith( + "What happened with Error Agent? 
It has an error — can you check?", + ); + }); + + test("generates success prompt for completed chips", () => { + const onChipClick = vi.fn(); + render( + , + ); + + fireEvent.click(screen.getByText("Ask")); + + expect(onChipClick).toHaveBeenCalledWith( + "Done Agent just finished a run — can you summarize what it did?", + ); + }); + + test("renders See link pointing to agent detail page", () => { + render(); + + const seeLink = screen.getByText("See").closest("a"); + expect(seeLink?.getAttribute("href")).toBe("/library/agents/agent-xyz"); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/PulseChips/types.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/components/PulseChips/types.ts new file mode 100644 index 0000000000..7650afe5ee --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/PulseChips/types.ts @@ -0,0 +1,13 @@ +import type { + AgentStatus, + SitrepPriority, +} from "@/app/(platform)/library/types"; + +export interface PulseChipData { + id: string; + agentID: string; + name: string; + status: AgentStatus; + priority: SitrepPriority; + shortMessage: string; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/PulseChips/usePulseChips.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/components/PulseChips/usePulseChips.ts new file mode 100644 index 0000000000..f1d56232fe --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/PulseChips/usePulseChips.ts @@ -0,0 +1,23 @@ +"use client"; + +import { useLibraryAgents } from "@/hooks/useLibraryAgents/useLibraryAgents"; +import { useSitrepItems } from "@/app/(platform)/library/components/SitrepItem/useSitrepItems"; +import type { PulseChipData } from "./types"; +import { useMemo } from "react"; + +export function usePulseChips(): PulseChipData[] { + const { agents } = useLibraryAgents(); + + const sitrepItems = useSitrepItems(agents, 5); + + return useMemo(() => { + return sitrepItems.map((item) => ({ + id: item.id, + agentID: item.agentID, + name: item.agentName, + status: item.status, + priority: item.priority, + shortMessage: item.message, + })); + }, [sitrepItems]); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/RateLimitResetDialog/RateLimitResetDialog.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/RateLimitResetDialog/RateLimitResetDialog.tsx index c704a5505d..22e892655a 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/RateLimitResetDialog/RateLimitResetDialog.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/RateLimitResetDialog/RateLimitResetDialog.tsx @@ -6,6 +6,9 @@ import { Dialog } from "@/components/molecules/Dialog/Dialog"; import { useRouter } from "next/navigation"; import { useEffect, useRef } from "react"; import { useResetRateLimit } from "../../hooks/useResetRateLimit"; +import { formatCents } from "../usageHelpers"; + +export { formatCents }; interface Props { isOpen: boolean; @@ -18,10 +21,6 @@ interface Props { onCreditChange?: () => void; } -export function formatCents(cents: number): string { - return `$${(cents / 100).toFixed(2)}`; -} - export function RateLimitResetDialog({ isOpen, onClose, diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/UsagePanelContent.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/UsagePanelContent.tsx index fe420d145d..91187816da 100644 --- 
a/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/UsagePanelContent.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/UsagePanelContent.tsx @@ -1,35 +1,10 @@ import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus"; import { Button } from "@/components/atoms/Button/Button"; import Link from "next/link"; -import { formatCents } from "../RateLimitResetDialog/RateLimitResetDialog"; +import { formatCents, formatResetTime } from "../usageHelpers"; import { useResetRateLimit } from "../../hooks/useResetRateLimit"; -export function formatResetTime( - resetsAt: Date | string, - now: Date = new Date(), -): string { - const resetDate = - typeof resetsAt === "string" ? new Date(resetsAt) : resetsAt; - const diffMs = resetDate.getTime() - now.getTime(); - if (diffMs <= 0) return "now"; - - const hours = Math.floor(diffMs / (1000 * 60 * 60)); - - // Under 24h: show relative time ("in 4h 23m") - if (hours < 24) { - const minutes = Math.floor((diffMs % (1000 * 60 * 60)) / (1000 * 60)); - if (hours > 0) return `in ${hours}h ${minutes}m`; - return `in ${minutes}m`; - } - - // Over 24h: show day and time in local timezone ("Mon 12:00 AM PST") - return resetDate.toLocaleString(undefined, { - weekday: "short", - hour: "numeric", - minute: "2-digit", - timeZoneName: "short", - }); -} +export { formatResetTime }; function UsageBar({ label, diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/usageHelpers.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/components/usageHelpers.ts new file mode 100644 index 0000000000..599442075f --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/usageHelpers.ts @@ -0,0 +1,28 @@ +export function formatCents(cents: number): string { + return `$${(cents / 100).toFixed(2)}`; +} + +export function formatResetTime( + resetsAt: Date | string, + now: Date = new Date(), +): string { + const resetDate = + typeof resetsAt === "string" ? 
new Date(resetsAt) : resetsAt; + const diffMs = resetDate.getTime() - now.getTime(); + if (diffMs <= 0) return "now"; + + const hours = Math.floor(diffMs / (1000 * 60 * 60)); + + if (hours < 24) { + const minutes = Math.floor((diffMs % (1000 * 60 * 60)) / (1000 * 60)); + if (hours > 0) return `in ${hours}h ${minutes}m`; + return `in ${minutes}m`; + } + + return resetDate.toLocaleString(undefined, { + weekday: "short", + hour: "numeric", + minute: "2-digit", + timeZoneName: "short", + }); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/helpers/__tests__/convertChatSessionToUiMessages.test.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers/__tests__/convertChatSessionToUiMessages.test.ts new file mode 100644 index 0000000000..33b2879cc9 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers/__tests__/convertChatSessionToUiMessages.test.ts @@ -0,0 +1,59 @@ +import { describe, expect, it } from "vitest"; +import { convertChatSessionMessagesToUiMessages } from "../convertChatSessionToUiMessages"; + +const SESSION_ID = "sess-test"; + +describe("convertChatSessionMessagesToUiMessages", () => { + it("does not drop user messages with null content", () => { + const result = convertChatSessionMessagesToUiMessages( + SESSION_ID, + [{ role: "user", content: null, sequence: 0 }], + { isComplete: true }, + ); + + expect(result.messages).toHaveLength(1); + expect(result.messages[0].role).toBe("user"); + }); + + it("does not drop user messages with empty string content", () => { + const result = convertChatSessionMessagesToUiMessages( + SESSION_ID, + [{ role: "user", content: "", sequence: 0 }], + { isComplete: true }, + ); + + expect(result.messages).toHaveLength(1); + expect(result.messages[0].role).toBe("user"); + }); + + it("still drops non-user messages with null content", () => { + const result = convertChatSessionMessagesToUiMessages( + SESSION_ID, + [{ role: "assistant", content: null, sequence: 0 }], + { isComplete: true }, + ); + + expect(result.messages).toHaveLength(0); + }); + + it("still drops non-user messages with empty string content", () => { + const result = convertChatSessionMessagesToUiMessages( + SESSION_ID, + [{ role: "assistant", content: "", sequence: 0 }], + { isComplete: true }, + ); + + expect(result.messages).toHaveLength(0); + }); + + it("includes user message with normal content", () => { + const result = convertChatSessionMessagesToUiMessages( + SESSION_ID, + [{ role: "user", content: "hello", sequence: 0 }], + { isComplete: true }, + ); + + expect(result.messages).toHaveLength(1); + expect(result.messages[0].role).toBe("user"); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/helpers/convertChatSessionToUiMessages.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers/convertChatSessionToUiMessages.ts index 5021d661f0..10b0ad52c1 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/helpers/convertChatSessionToUiMessages.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers/convertChatSessionToUiMessages.ts @@ -253,6 +253,11 @@ export function convertChatSessionMessagesToUiMessages( } } + // User messages must always be rendered, even with empty content, so the + // initial prompt is visible when reloading a session. 
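+  // For example, a raw message of { role: "user", content: null, sequence: 0 }
+  // (one of the fixtures exercised in the tests above) now yields a UIMessage
+  // with a single empty text part instead of being dropped.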
+ if (parts.length === 0 && msg.role === "user") { + parts.push({ type: "text", text: "", state: "done" }); + } if (parts.length === 0) return; // Merge consecutive assistant messages into a single UIMessage diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/useChatSession.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/useChatSession.ts index b5a02620c2..8357ee8af9 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/useChatSession.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/useChatSession.ts @@ -86,6 +86,16 @@ export function useChatSession({ dryRun = false }: UseChatSessionOptions = {}) { return sessionQuery.data.data.oldest_sequence ?? null; }, [sessionQuery.data]); + const newestSequence = useMemo(() => { + if (sessionQuery.data?.status !== 200) return null; + return sessionQuery.data.data.newest_sequence ?? null; + }, [sessionQuery.data]); + + const forwardPaginated = useMemo(() => { + if (sessionQuery.data?.status !== 200) return false; + return !!sessionQuery.data.data.forward_paginated; + }, [sessionQuery.data]); + // Memoize so the effect in useCopilotPage doesn't infinite-loop on a new // array reference every render. Re-derives only when query data changes. // When the session is complete (no active stream), mark dangling tool @@ -185,6 +195,8 @@ export function useChatSession({ dryRun = false }: UseChatSessionOptions = {}) { hasActiveStream, hasMoreMessages, oldestSequence, + newestSequence, + forwardPaginated, isLoadingSession: sessionQuery.isLoading, isSessionError: sessionQuery.isError, createSession, diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts index 9e118c2bbc..3e9be079db 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts @@ -56,6 +56,8 @@ export function useCopilotPage() { hasActiveStream, hasMoreMessages, oldestSequence, + newestSequence, + forwardPaginated, isLoadingSession, isSessionError, createSession, @@ -84,18 +86,26 @@ export function useCopilotPage() { copilotModel: isModeToggleEnabled ? copilotLlmModel : undefined, }); - const { olderMessages, hasMore, isLoadingMore, loadMore } = + const { pagedMessages, hasMore, isLoadingMore, loadMore, resetPaged } = useLoadMoreMessages({ sessionId, initialOldestSequence: oldestSequence, + initialNewestSequence: newestSequence, initialHasMore: hasMoreMessages, + forwardPaginated, initialPageRawMessages: rawSessionMessages, }); - // Combine older (paginated) messages with current page messages, - // merging consecutive assistant UIMessages at the page boundary so - // reasoning + response parts stay in a single bubble. - const messages = concatWithAssistantMerge(olderMessages, currentMessages); + // Combine paginated messages with current page messages, merging consecutive + // assistant UIMessages at the page boundary so reasoning + response parts + // stay in a single bubble. + // Forward pagination (completed sessions): current page is the beginning, + // paged messages are newer pages appended after. + // Backward pagination (active sessions): paged messages are older history + // prepended before the current page. + const messages = forwardPaginated + ? 
concatWithAssistantMerge(currentMessages, pagedMessages) + : concatWithAssistantMerge(pagedMessages, currentMessages); useCopilotNotifications(sessionId); @@ -170,6 +180,23 @@ export function useCopilotPage() { } }, [sessionId, pendingMessage, sendMessage]); + // --- Clear backward-paginated messages when session completes --- + // When a session transitions from active (forwardPaginated=false) to complete + // (forwardPaginated=true), any backward-paginated older messages would be + // appended after currentMessages instead of before, causing chronological + // disorder. Reset paged state so the completed session renders cleanly. + const prevForwardPaginatedRef = useRef(forwardPaginated); + useEffect(() => { + if ( + !prevForwardPaginatedRef.current && + forwardPaginated && + pagedMessages.length > 0 + ) { + resetPaged(); + } + prevForwardPaginatedRef.current = forwardPaginated; + }, [forwardPaginated, pagedMessages.length, resetPaged]); + // --- Extract prompt from URL hash on mount (e.g. /copilot#prompt=Hello) --- useWorkflowImportAutoSubmit({ createSession, @@ -251,6 +278,15 @@ export function useCopilotPage() { isUserStoppingRef.current = false; if (sessionId) { + // When continuing a completed session that had forward-paginated history + // loaded, the paged messages would appear in wrong position relative to + // the new streaming turn (pagedMessages are newer pages, so they'd end + // up after the streaming turn). Reset paged state so ordering is correct + // during streaming; the user can reload history afterward if needed. + if (forwardPaginated && pagedMessages.length > 0) { + resetPaged(); + } + if (files && files.length > 0) { setIsUploadingFiles(true); try { @@ -397,6 +433,7 @@ export function useCopilotPage() { hasMoreMessages: hasMore, isLoadingMore, loadMore, + forwardPaginated, // Mobile drawer isMobile, isDrawerOpen, diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/useLoadMoreMessages.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/useLoadMoreMessages.ts index 313b2d5fb8..7c3f1b7c24 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/useLoadMoreMessages.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/useLoadMoreMessages.ts @@ -9,7 +9,11 @@ import { interface UseLoadMoreMessagesArgs { sessionId: string | null; initialOldestSequence: number | null; + initialNewestSequence: number | null; initialHasMore: boolean; + /** True when the initial page was loaded from sequence 0 forward (completed + * sessions). False when loaded newest-first (active sessions). */ + forwardPaginated: boolean; /** Raw messages from the initial page, used for cross-page tool output matching. */ initialPageRawMessages: unknown[]; } @@ -20,16 +24,21 @@ const MAX_OLDER_MESSAGES = 2000; export function useLoadMoreMessages({ sessionId, initialOldestSequence, + initialNewestSequence, initialHasMore, + forwardPaginated, initialPageRawMessages, }: UseLoadMoreMessagesArgs) { - // Store accumulated raw messages from all older pages (in ascending order). + // Accumulated raw messages from all extra pages (ascending order). // Re-converting them all together ensures tool outputs are matched across // inter-page boundaries. 
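  // (E.g. when one page ends with an assistant tool call and the next page
  // starts with its tool result, a single conversion pass can stitch the call
  // and output into the same rendered message.)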
- const [olderRawMessages, setOlderRawMessages] = useState([]); + const [pagedRawMessages, setPagedRawMessages] = useState([]); const [oldestSequence, setOldestSequence] = useState( initialOldestSequence, ); + const [newestSequence, setNewestSequence] = useState( + initialNewestSequence, + ); const [hasMore, setHasMore] = useState(initialHasMore); const [isLoadingMore, setIsLoadingMore] = useState(false); const isLoadingMoreRef = useRef(false); @@ -46,7 +55,7 @@ export function useLoadMoreMessages({ // The parent's `initialOldestSequence` drifts forward every time the // session query refetches (e.g. after a stream completes — see // `useCopilotStream` invalidation on `streaming → ready`). If we - // wiped `olderRawMessages` every time that happened, users who had + // wiped `pagedRawMessages` every time that happened, users who had // scrolled back would lose their loaded history on each new turn and // subsequent `loadMore` calls would fetch messages that overlap with // the AI SDK's retained state in `currentMessages`, producing visible @@ -63,8 +72,9 @@ export function useLoadMoreMessages({ // Session changed — full reset prevSessionIdRef.current = sessionId; prevInitialOldestRef.current = initialOldestSequence; - setOlderRawMessages([]); + setPagedRawMessages([]); setOldestSequence(initialOldestSequence); + setNewestSequence(initialNewestSequence); setHasMore(initialHasMore); setIsLoadingMore(false); isLoadingMoreRef.current = false; @@ -75,49 +85,64 @@ export function useLoadMoreMessages({ prevInitialOldestRef.current = initialOldestSequence; - // If we haven't paged back yet, mirror the parent so the first + // If we haven't paged yet, mirror the parent so the first // `loadMore` starts from the correct cursor. - if (olderRawMessages.length === 0) { + // + // When paged messages exist (pagedRawMessages.length > 0) we intentionally + // do NOT update `hasMore` or `newestSequence` from the parent. A parent + // refetch (e.g. after a new turn completes) may carry a fresh + // `initialHasMore=true` or a larger `initialNewestSequence`, but those + // reflect the *initial* page window, not the forward-paged window we have + // already advanced into. Overwriting the local cursor here would cause the + // next `loadMore` to re-fetch pages we already have. The local cursor is + // advanced correctly inside `loadMore` itself via `setNewestSequence`. + if (pagedRawMessages.length === 0) { setOldestSequence(initialOldestSequence); + // Only regress the forward cursor if we haven't paged ahead yet — + // otherwise a parent refetch would reset a cursor we already advanced. + setNewestSequence((prev) => + prev !== null && prev > (initialNewestSequence ?? -1) + ? prev + : initialNewestSequence, + ); setHasMore(initialHasMore); } - }, [sessionId, initialOldestSequence, initialHasMore]); + }, [sessionId, initialOldestSequence, initialNewestSequence, initialHasMore]); // Convert all accumulated raw messages in one pass so tool outputs - // are matched across inter-page boundaries. Initial page tool outputs - // are included via extraToolOutputs to handle the boundary between - // the last older page and the initial/streaming page. - const olderMessages: UIMessage[] = + // are matched across inter-page boundaries. + // For backward pagination only: include initial page tool outputs so older + // paged pages can match tool calls whose outputs landed in the initial page. 
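+  // (Concretely: a tool call at the tail of the last-loaded older page whose
+  // tool-result message is the first message of the initial page.)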
+ // For forward pagination this is unnecessary — tool calls in newer paged + // pages cannot have their outputs in the older initial page. + const pagedMessages: UIMessage[] = useMemo(() => { - if (!sessionId || olderRawMessages.length === 0) return []; + if (!sessionId || pagedRawMessages.length === 0) return []; const extraToolOutputs = - initialPageRawMessages.length > 0 + !forwardPaginated && initialPageRawMessages.length > 0 ? extractToolOutputsFromRaw(initialPageRawMessages) : undefined; return convertChatSessionMessagesToUiMessages( sessionId, - olderRawMessages, + pagedRawMessages, { isComplete: true, extraToolOutputs }, ).messages; - }, [sessionId, olderRawMessages, initialPageRawMessages]); + }, [sessionId, pagedRawMessages, initialPageRawMessages, forwardPaginated]); async function loadMore() { - if ( - !sessionId || - !hasMore || - isLoadingMoreRef.current || - oldestSequence === null - ) - return; + if (!sessionId || !hasMore || isLoadingMoreRef.current) return; + + const cursor = forwardPaginated ? newestSequence : oldestSequence; + if (cursor === null) return; const requestEpoch = epochRef.current; isLoadingMoreRef.current = true; setIsLoadingMore(true); try { - const response = await getV2GetSession(sessionId, { - limit: 50, - before_sequence: oldestSequence, - }); + const params = forwardPaginated + ? { limit: 50, after_sequence: cursor } + : { limit: 50, before_sequence: cursor }; + const response = await getV2GetSession(sessionId, params); // Discard response if session/pagination was reset while awaiting if (epochRef.current !== requestEpoch) return; @@ -136,18 +161,66 @@ export function useLoadMoreMessages({ consecutiveErrorsRef.current = 0; const newRaw = (response.data.messages ?? []) as unknown[]; - setOlderRawMessages((prev) => { - const merged = [...newRaw, ...prev]; + // Estimate total after merge using the closure-captured pagedRawMessages.length. + // This is a safe approximation: worst case it's one page stale (one extra load + // allowed), but it avoids the React-18-batching pitfall where a functional + // updater's mutations are not visible until the next render. + const estimatedTotal = pagedRawMessages.length + newRaw.length; + setPagedRawMessages((prev) => { + // Forward: append to end. Backward: prepend to start. + const merged = forwardPaginated + ? [...prev, ...newRaw] + : [...newRaw, ...prev]; if (merged.length > MAX_OLDER_MESSAGES) { - return merged.slice(merged.length - MAX_OLDER_MESSAGES); + // Backward: discard the oldest (front) items — user has scrolled far + // back and we shed the furthest history. + // Forward: discard the newest (tail) items — we only ever fetch + // forward, so the tail is the most recently appended page; shedding + // it means the sentinel stalls, which is safer than discarding the + // beginning of the conversation the user is here to read. + return forwardPaginated + ? merged.slice(0, MAX_OLDER_MESSAGES) + : merged.slice(merged.length - MAX_OLDER_MESSAGES); } return merged; }); - setOldestSequence(response.data.oldest_sequence ?? null); - if (newRaw.length + olderRawMessages.length >= MAX_OLDER_MESSAGES) { - setHasMore(false); + + if (forwardPaginated) { + const willTruncateForward = estimatedTotal > MAX_OLDER_MESSAGES; + if (willTruncateForward) { + // Truncation shed the newest tail. Advance the cursor to the last KEPT + // item's sequence so the sentinel re-fetches the discarded items next + // time rather than jumping past them. + // lastKeptIdx: index within newRaw of the last item that survives. 
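+          // Worked example: MAX_OLDER_MESSAGES=2000 with 1980 items already paged
+          // and a 50-item page fetched — the slice keeps newRaw[0..19], so
+          // lastKeptIdx = 2000 - 1 - 1980 = 19.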
+ // prev contributes pagedRawMessages.length items; total kept = MAX. + const lastKeptIdx = MAX_OLDER_MESSAGES - 1 - pagedRawMessages.length; + if (lastKeptIdx >= 0 && lastKeptIdx < newRaw.length) { + const lastKeptMsg = newRaw[lastKeptIdx] as { sequence?: number }; + if (typeof lastKeptMsg?.sequence === "number") { + setNewestSequence(lastKeptMsg.sequence); + setHasMore(true); // Discarded items still exist — keep sentinel active + } else { + // Sequence unavailable — fall back; truncated items will be lost + setNewestSequence(response.data.newest_sequence ?? null); + setHasMore(!!response.data.has_more_messages); + } + } else { + // All of newRaw was dropped (already at MAX_OLDER_MESSAGES cap). + // Stop to avoid an infinite re-fetch loop at the display cap. + setHasMore(false); + } + } else { + setNewestSequence(response.data.newest_sequence ?? null); + setHasMore(!!response.data.has_more_messages); + } } else { - setHasMore(!!response.data.has_more_messages); + setOldestSequence(response.data.oldest_sequence ?? null); + if (estimatedTotal >= MAX_OLDER_MESSAGES) { + // Backward: accumulated MAX_OLDER_MESSAGES — stop to avoid unbounded memory. + setHasMore(false); + } else { + setHasMore(!!response.data.has_more_messages); + } } } catch (error) { if (epochRef.current !== requestEpoch) return; @@ -164,5 +237,22 @@ export function useLoadMoreMessages({ } } - return { olderMessages, hasMore, isLoadingMore, loadMore }; + function resetPaged() { + setPagedRawMessages([]); + setOldestSequence(initialOldestSequence); + setNewestSequence(initialNewestSequence); + // Set hasMore=false during the session-transition window so no loadMore + // fires with forward pagination (after_sequence) on the now-active session. + // The useEffect will restore hasMore from the parent after the refetch + // completes and forwardPaginated switches to false. + setHasMore(false); + // Clear the loading state so the spinner doesn't stay stuck if a loadMore + // was in flight when resetPaged was called. + setIsLoadingMore(false); + isLoadingMoreRef.current = false; + consecutiveErrorsRef.current = 0; + epochRef.current += 1; + } + + return { pagedMessages, hasMore, isLoadingMore, loadMore, resetPaged }; } diff --git a/autogpt_platform/frontend/src/app/(platform)/layout.tsx b/autogpt_platform/frontend/src/app/(platform)/layout.tsx index 048110f8b2..0d72326e17 100644 --- a/autogpt_platform/frontend/src/app/(platform)/layout.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/layout.tsx @@ -2,14 +2,17 @@ import { Navbar } from "@/components/layout/Navbar/Navbar"; import { NetworkStatusMonitor } from "@/services/network-status/NetworkStatusMonitor"; import { ReactNode } from "react"; import { AdminImpersonationBanner } from "./admin/components/AdminImpersonationBanner"; +import { AutoPilotBridgeProvider } from "@/contexts/AutoPilotBridgeContext"; export default function PlatformLayout({ children }: { children: ReactNode }) { return ( -
-    <div>
-      <NetworkStatusMonitor />
-      <Navbar />
-      <AdminImpersonationBanner />
-      <main>{children}</main>
-    </div>
+    <AutoPilotBridgeProvider>
+      <div>
+        <NetworkStatusMonitor />
+        <Navbar />
+        <AdminImpersonationBanner />
+        <main>{children}</main>
+      </div>
+    </AutoPilotBridgeProvider>
); } diff --git a/autogpt_platform/frontend/src/app/(platform)/library/__tests__/main.test.tsx b/autogpt_platform/frontend/src/app/(platform)/library/__tests__/main.test.tsx index 8d7960dc9b..6f6d7f3794 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/__tests__/main.test.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/__tests__/main.test.tsx @@ -137,8 +137,10 @@ describe("LibraryPage", () => { user_id: "test-user", name: "Work Agents", agent_count: 3, + subfolder_count: 0, color: null, icon: null, + parent_id: null, created_at: new Date(), updated_at: new Date(), }, @@ -147,8 +149,10 @@ describe("LibraryPage", () => { user_id: "test-user", name: "Personal", agent_count: 1, + subfolder_count: 0, color: null, icon: null, + parent_id: null, created_at: new Date(), updated_at: new Date(), }, @@ -158,12 +162,14 @@ describe("LibraryPage", () => { render(); + await waitForAgentsToLoad(); + expect(await screen.findByText("Work Agents")).toBeDefined(); expect(screen.getByText("Personal")).toBeDefined(); expect(screen.getAllByTestId("library-folder")).toHaveLength(2); }); - test("shows See runs link on agent card", async () => { + test("shows See tasks link on agent card", async () => { setupHandlers({ agents: [makeAgent({ name: "Linked Agent", can_access_graph: true })], }); @@ -172,7 +178,7 @@ describe("LibraryPage", () => { await screen.findByText("Linked Agent"); - const runLinks = screen.getAllByText("See runs"); + const runLinks = screen.getAllByText("See tasks"); expect(runLinks.length).toBeGreaterThan(0); }); @@ -190,7 +196,7 @@ describe("LibraryPage", () => { expect(importButtons.length).toBeGreaterThan(0); }); - test("renders Jump Back In when there is an active execution", async () => { + test("renders running agent card when execution is active", async () => { const agent = makeAgent({ id: "lib-1", graph_id: "g-1", @@ -218,6 +224,6 @@ describe("LibraryPage", () => { render(); - expect(await screen.findByText("Jump Back In")).toBeDefined(); + expect(await screen.findByText("Running Agent")).toBeDefined(); }); }); diff --git a/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/AgentBriefingPanel.module.css b/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/AgentBriefingPanel.module.css new file mode 100644 index 0000000000..8c2bd69313 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/AgentBriefingPanel.module.css @@ -0,0 +1,44 @@ +.glassPanel { + position: relative; + isolation: isolate; +} + +.glassPanel::before { + content: ""; + position: absolute; + inset: 0; + border-radius: inherit; + padding: 1px; + background: conic-gradient( + from var(--border-angle, 0deg), + rgba(129, 120, 228, 0.04), + rgba(129, 120, 228, 0.14), + rgba(168, 130, 255, 0.09), + rgba(129, 120, 228, 0.04), + rgba(99, 102, 241, 0.12), + rgba(129, 120, 228, 0.04) + ); + -webkit-mask: + linear-gradient(#000 0 0) content-box, + linear-gradient(#000 0 0); + mask: + linear-gradient(#000 0 0) content-box, + linear-gradient(#000 0 0); + -webkit-mask-composite: xor; + mask-composite: exclude; + animation: rotate-border 6s linear infinite; + pointer-events: none; + z-index: -1; +} + +@property --border-angle { + syntax: ""; + initial-value: 0deg; + inherits: false; +} + +@keyframes rotate-border { + to { + --border-angle: 360deg; + } +} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/AgentBriefingPanel.tsx 
b/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/AgentBriefingPanel.tsx new file mode 100644 index 0000000000..82bf930b7b --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/AgentBriefingPanel.tsx @@ -0,0 +1,36 @@ +"use client"; + +import { Text } from "@/components/atoms/Text/Text"; +import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; +import { useState } from "react"; +import type { FleetSummary, AgentStatusFilter } from "../../types"; +import { BriefingTabContent } from "./BriefingTabContent"; +import { StatsGrid } from "./StatsGrid"; +import styles from "./AgentBriefingPanel.module.css"; + +interface Props { + summary: FleetSummary; + agents: LibraryAgent[]; +} + +export function AgentBriefingPanel({ summary, agents }: Props) { + const [userTab, setUserTab] = useState(null); + const activeTab: AgentStatusFilter = + userTab ?? (summary.running > 0 ? "running" : "all"); + + return ( +
+    <section className={styles.glassPanel}>
+      <Text>Agent Briefing</Text>
+      <StatsGrid
+        summary={summary}
+        activeTab={activeTab}
+        onTabChange={setUserTab}
+      />
+      <BriefingTabContent activeTab={activeTab} agents={agents} />
+    </section>
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/BriefingTabContent.tsx b/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/BriefingTabContent.tsx new file mode 100644 index 0000000000..5d4df627d9 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/BriefingTabContent.tsx @@ -0,0 +1,347 @@ +"use client"; + +import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus"; +import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; +import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat"; +import { + formatResetTime, + formatCents, +} from "@/app/(platform)/copilot/components/usageHelpers"; +import { useResetRateLimit } from "@/app/(platform)/copilot/hooks/useResetRateLimit"; +import { Button } from "@/components/atoms/Button/Button"; +import { Badge } from "@/components/atoms/Badge/Badge"; +import useCredits from "@/hooks/useCredits"; +import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag"; +import { useSitrepItems } from "../SitrepItem/useSitrepItems"; +import { SitrepItem } from "../SitrepItem/SitrepItem"; +import { useAgentStatusMap } from "../../hooks/useAgentStatus"; +import type { AgentStatusFilter } from "../../types"; +import { Text } from "@/components/atoms/Text/Text"; +import Link from "next/link"; +import { useState } from "react"; + +interface Props { + activeTab: AgentStatusFilter; + agents: LibraryAgent[]; +} + +export function BriefingTabContent({ activeTab, agents }: Props) { + if (activeTab === "all") { + return ; + } + + if ( + activeTab === "running" || + activeTab === "attention" || + activeTab === "completed" + ) { + return ; + } + + return ; +} + +function UsageSection() { + const { data: usage } = useGetV2GetCopilotUsage({ + query: { + select: (res) => res.data as CoPilotUsageStatus, + refetchInterval: 30000, + staleTime: 10000, + }, + }); + + const isBillingEnabled = useGetFlag(Flag.ENABLE_PLATFORM_PAYMENT); + const { credits, fetchCredits } = useCredits({ fetchInitialCredits: true }); + const resetCost = usage?.reset_cost; + const hasInsufficientCredits = + credits !== null && resetCost != null && credits < resetCost; + + if (!usage?.daily || !usage?.weekly) return null; + + return ( +
+
+ + Usage limits + + {usage.tier && ( + + {usage.tier.charAt(0) + usage.tier.slice(1).toLowerCase()} plan + + )} +
+ {isBillingEnabled && ( + + Manage billing + + )} +
+
+ {usage.daily.limit > 0 && ( + + )} + {usage.weekly.limit > 0 && ( + + )} +
+ +
+ ); +} + +const MAX_VISIBLE = 6; + +function ExecutionListSection({ + activeTab, + agents, +}: { + activeTab: AgentStatusFilter; + agents: LibraryAgent[]; +}) { + const allItems = useSitrepItems(agents, 50); + const [showAll, setShowAll] = useState(false); + + const filtered = allItems.filter((item) => { + if (activeTab === "running") return item.priority === "running"; + if (activeTab === "attention") return item.priority === "error"; + if (activeTab === "completed") return item.priority === "success"; + return false; + }); + + if (filtered.length === 0) { + return ; + } + + const visible = showAll ? filtered : filtered.slice(0, MAX_VISIBLE); + const hasMore = filtered.length > MAX_VISIBLE; + + return ( +
+    <div>
+      <div>
+        {visible.map((item) => (
+          <SitrepItem key={item.id} item={item} />
+        ))}
+      </div>
+      {hasMore && (
+        <div>
+          <Button onClick={() => setShowAll((s) => !s)}>
+            {showAll ? "Show less" : "Show all"}
+          </Button>
+        </div>
+      )}
+    </div>
+ ); +} + +const TAB_STATUS_LABEL: Record = { + listening: "Waiting for trigger event", + scheduled: "Has a scheduled run", + idle: "No recent activity", +}; + +function AgentListSection({ + activeTab, + agents, +}: { + activeTab: AgentStatusFilter; + agents: LibraryAgent[]; +}) { + const [showAll, setShowAll] = useState(false); + const statusMap = useAgentStatusMap(agents); + + const filtered = agents.filter((agent) => { + const status = statusMap.get(agent.graph_id)?.status; + if (activeTab === "listening") return status === "listening"; + if (activeTab === "scheduled") return status === "scheduled"; + if (activeTab === "idle") return status === "idle"; + return false; + }); + + if (filtered.length === 0) { + return ; + } + + const status = + activeTab === "listening" + ? ("listening" as const) + : activeTab === "scheduled" + ? ("scheduled" as const) + : ("idle" as const); + + const visible = showAll ? filtered : filtered.slice(0, MAX_VISIBLE); + const hasMore = filtered.length > MAX_VISIBLE; + + return ( +
+
+ {visible.map((agent) => ( + + ))} +
+ {hasMore && ( +
+ +
+ )} +
+ ); +} + +function UsageFooter({ + usage, + hasInsufficientCredits, + onCreditChange, +}: { + usage: CoPilotUsageStatus; + hasInsufficientCredits: boolean; + onCreditChange?: () => void; +}) { + const isDailyExhausted = + usage.daily.limit > 0 && usage.daily.used >= usage.daily.limit; + const isWeeklyExhausted = + usage.weekly.limit > 0 && usage.weekly.used >= usage.weekly.limit; + const resetCost = usage.reset_cost ?? 0; + const { resetUsage, isPending } = useResetRateLimit({ onCreditChange }); + + const showReset = + isDailyExhausted && + !isWeeklyExhausted && + resetCost > 0 && + !hasInsufficientCredits; + + const showAddCredits = + isDailyExhausted && !isWeeklyExhausted && hasInsufficientCredits; + + if (!showReset && !showAddCredits) return null; + + return ( +
+    <div>
+      {showReset && (
+        <Button onClick={() => resetUsage()} disabled={isPending}>
+          Reset for {formatCents(resetCost)}
+        </Button>
+      )}
+      {showAddCredits && (
+        <Link href="/profile/credits">Add credits to reset</Link>
+      )}
+    </div>
+ ); +} + +function UsageMeter({ + label, + used, + limit, + resetsAt, +}: { + label: string; + used: number; + limit: number; + resetsAt: Date | string; +}) { + if (limit <= 0) return null; + + const rawPercent = (used / limit) * 100; + const percent = Math.min(100, Math.round(rawPercent)); + const isHigh = percent >= 80; + const percentLabel = + used > 0 && percent === 0 ? "<1% used" : `${percent}% used`; + + return ( +
+    <div>
+      <div>
+        <Text>{label}</Text>
+        <Text>{percentLabel}</Text>
+      </div>
+      <div>
+        <div
+          style={{ width: `${Math.max(percent > 0 ? 1 : 0, percent)}%` }}
+        />
+      </div>
+      <div>
+        <Text>
+          {used.toLocaleString()} / {limit.toLocaleString()}
+        </Text>
+        <Text>Resets {formatResetTime(resetsAt)}</Text>
+      </div>
+    </div>
+ ); +} + +const EMPTY_MESSAGES: Record = { + running: "No agents running right now", + attention: "No agents that need attention", + completed: "No recently completed runs", + listening: "No agents listening for events", + scheduled: "No agents with scheduled runs", + idle: "No idle agents", +}; + +function EmptyMessage({ tab }: { tab: AgentStatusFilter }) { + return ( +
+ + {EMPTY_MESSAGES[tab] ?? "No agents in this category"} + +
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/StatsGrid.tsx b/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/StatsGrid.tsx new file mode 100644 index 0000000000..d887776b22 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/StatsGrid.tsx @@ -0,0 +1,102 @@ +"use client"; + +import { Text } from "@/components/atoms/Text/Text"; +import { OverflowText } from "@/components/atoms/OverflowText/OverflowText"; +import { Emoji } from "@/components/atoms/Emoji/Emoji"; +import { cn } from "@/lib/utils"; +import type { FleetSummary, AgentStatusFilter } from "../../types"; + +interface Props { + summary: FleetSummary; + activeTab: AgentStatusFilter; + onTabChange: (tab: AgentStatusFilter) => void; +} + +const TILES: { + label: string; + key: keyof FleetSummary; + format?: (v: number) => string; + filter: AgentStatusFilter; + emoji: string; + color: string; +}[] = [ + { + label: "Spent this month", + key: "monthlySpend", + format: (v) => `$${v.toLocaleString()}`, + filter: "all", + emoji: "💵", + color: "text-zinc-700", + }, + { + label: "Running now", + key: "running", + filter: "running", + emoji: "🚩", + color: "text-blue-600", + }, + { + label: "Recently completed", + key: "completed", + filter: "completed", + emoji: "🗃️", + color: "text-green-600", + }, + { + label: "Needs attention", + key: "error", + filter: "attention", + emoji: "⚠️", + color: "text-red-500", + }, + { + label: "Scheduled", + key: "scheduled", + filter: "scheduled", + emoji: "📅", + color: "text-yellow-600", + }, + { + label: "Idle", + key: "idle", + filter: "idle", + emoji: "💤", + color: "text-zinc-400", + }, +]; + +export function StatsGrid({ summary, activeTab, onTabChange }: Props) { + return ( +
+ {TILES.map((tile) => { + const rawValue = summary[tile.key]; + const value = tile.format ? tile.format(rawValue) : rawValue; + const isActive = activeTab === tile.filter; + + return ( + + ); + })} +
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/components/AgentFilterMenu/AgentFilterMenu.tsx b/autogpt_platform/frontend/src/app/(platform)/library/components/AgentFilterMenu/AgentFilterMenu.tsx new file mode 100644 index 0000000000..b247c0dcf3 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/library/components/AgentFilterMenu/AgentFilterMenu.tsx @@ -0,0 +1,52 @@ +"use client"; + +import type { SelectOption } from "@/components/atoms/Select/Select"; +import { Select } from "@/components/atoms/Select/Select"; +import { FunnelIcon } from "@phosphor-icons/react"; +import type { AgentStatusFilter, FleetSummary } from "../../types"; + +interface Props { + value: AgentStatusFilter; + onChange: (value: AgentStatusFilter) => void; + summary: FleetSummary; +} + +function buildOptions(summary: FleetSummary): SelectOption[] { + return [ + { value: "all", label: "All Agents" }, + { value: "running", label: `Running (${summary.running})` }, + { value: "attention", label: `Needs Attention (${summary.error})` }, + { value: "listening", label: `Listening (${summary.listening})` }, + { value: "scheduled", label: `Scheduled (${summary.scheduled})` }, + { value: "idle", label: `Idle / Stale (${summary.idle})` }, + { value: "healthy", label: "Healthy" }, + ]; +} + +export function AgentFilterMenu({ value, onChange, summary }: Props) { + function handleChange(val: string) { + onChange(val as AgentStatusFilter); + } + + const options = buildOptions(summary); + + return ( +
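Example of what buildOptions produces: the live counts are baked into the label strings, so the dropdown text updates whenever the fleet summary re-renders.

const opts = buildOptions({
  running: 3,
  error: 1,
  completed: 5,
  listening: 2,
  scheduled: 0,
  idle: 7,
  monthlySpend: 0,
});
// opts[1].label === "Running (3)"
// opts[5].label === "Idle / Stale (7)"
// "Healthy" intentionally carries no count; FleetSummary has no healthy field.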
+ + filter + + + - + diff --git a/autogpt_platform/frontend/src/app/(platform)/library/components/LibrarySubSection/LibrarySubSection.tsx b/autogpt_platform/frontend/src/app/(platform)/library/components/LibrarySubSection/LibrarySubSection.tsx index 32169cf441..3a4475d55d 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/components/LibrarySubSection/LibrarySubSection.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/components/LibrarySubSection/LibrarySubSection.tsx @@ -6,9 +6,10 @@ import { } from "@/components/molecules/TabsLine/TabsLine"; import { LibraryAgentSort } from "@/app/api/__generated__/models/libraryAgentSort"; import { useFavoriteAnimation } from "../../context/FavoriteAnimationContext"; -import { LibraryTab } from "../../types"; +import type { LibraryTab, AgentStatusFilter, FleetSummary } from "../../types"; import LibraryFolderCreationDialog from "../LibraryFolderCreationDialog/LibraryFolderCreationDialog"; import { LibrarySortMenu } from "../LibrarySortMenu/LibrarySortMenu"; +import { AgentFilterMenu } from "../AgentFilterMenu/AgentFilterMenu"; interface Props { tabs: LibraryTab[]; @@ -17,6 +18,9 @@ interface Props { allCount: number; favoritesCount: number; setLibrarySort: (value: LibraryAgentSort) => void; + statusFilter?: AgentStatusFilter; + onStatusFilterChange?: (filter: AgentStatusFilter) => void; + fleetSummary?: FleetSummary; } export function LibrarySubSection({ @@ -26,6 +30,9 @@ export function LibrarySubSection({ allCount, favoritesCount, setLibrarySort, + statusFilter = "all", + onStatusFilterChange, + fleetSummary, }: Props) { const { registerFavoritesTabRef } = useFavoriteAnimation(); const favoritesRef = useRef(null); @@ -68,8 +75,15 @@ export function LibrarySubSection({ ))} -
+
+ {fleetSummary && onStatusFilterChange && ( + + )}
diff --git a/autogpt_platform/frontend/src/app/(platform)/library/components/SitrepItem/SitrepItem.module.css b/autogpt_platform/frontend/src/app/(platform)/library/components/SitrepItem/SitrepItem.module.css new file mode 100644 index 0000000000..56d9944327 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/library/components/SitrepItem/SitrepItem.module.css @@ -0,0 +1,17 @@ +.spinner { + aspect-ratio: 1; + border-radius: 50%; + background: + radial-gradient(farthest-side, currentColor 94%, #0000) top/3px 3px + no-repeat, + conic-gradient(#0000 30%, currentColor); + -webkit-mask: radial-gradient(farthest-side, #0000 calc(100% - 3px), #000 0); + mask: radial-gradient(farthest-side, #0000 calc(100% - 3px), #000 0); + animation: spin 1s infinite linear; +} + +@keyframes spin { + 100% { + transform: rotate(1turn); + } +} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/components/SitrepItem/SitrepItem.tsx b/autogpt_platform/frontend/src/app/(platform)/library/components/SitrepItem/SitrepItem.tsx new file mode 100644 index 0000000000..3277b06716 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/library/components/SitrepItem/SitrepItem.tsx @@ -0,0 +1,172 @@ +"use client"; + +import { Text } from "@/components/atoms/Text/Text"; +import { + WarningCircleIcon, + ClockCountdownIcon, + CheckCircleIcon, + ChatCircleDotsIcon, + EarIcon, + CalendarDotsIcon, + MoonIcon, + EyeIcon, +} from "@phosphor-icons/react"; +import NextLink from "next/link"; +import { cn } from "@/lib/utils"; +import { useRouter } from "next/navigation"; +import type { SitrepItemData, SitrepPriority } from "../../types"; +import { ContextualActionButton } from "../ContextualActionButton/ContextualActionButton"; +import styles from "./SitrepItem.module.css"; + +interface Props { + item: SitrepItemData; +} + +const PRIORITY_CONFIG: Record< + SitrepPriority, + { + icon?: typeof WarningCircleIcon; + color: string; + bg: string; + cssSpinner?: boolean; + } +> = { + error: { + icon: WarningCircleIcon, + color: "text-red-500", + bg: "bg-red-50", + }, + running: { + color: "text-zinc-800", + bg: "", + cssSpinner: true, + }, + stale: { + icon: ClockCountdownIcon, + color: "text-yellow-600", + bg: "bg-yellow-50", + }, + success: { + icon: CheckCircleIcon, + color: "text-green-600", + bg: "bg-green-50", + }, + listening: { + icon: EarIcon, + color: "text-purple-500", + bg: "bg-purple-50", + }, + scheduled: { + icon: CalendarDotsIcon, + color: "text-yellow-600", + bg: "bg-yellow-50", + }, + idle: { + icon: MoonIcon, + color: "text-zinc-400", + bg: "bg-zinc-100", + }, +}; + +export function SitrepItem({ item }: Props) { + const config = PRIORITY_CONFIG[item.priority]; + const router = useRouter(); + + function handleAskAutoPilot() { + const prompt = buildAutoPilotPrompt(item); + const encoded = encodeURIComponent(prompt); + router.push(`/copilot?autosubmit=true#prompt=${encoded}`); + } + + return ( +
+
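handleAskAutoPilot above ships the prompt in the URL fragment rather than a query parameter, so the text never reaches the server. A hedged sketch of what the receiving /copilot page might do; the reader side is not part of this diff, so the helper name is illustrative:

function readPromptFromHash(hash: string): string | null {
  const match = hash.match(/^#prompt=(.+)$/);
  return match ? decodeURIComponent(match[1]) : null;
}

// readPromptFromHash(`#prompt=${encodeURIComponent("What happened?")}`)
//   -> "What happened?"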
+ {item.agentImageUrl ? ( + {item.agentName} + ) : ( +
+ {config.cssSpinner ? ( +
+ ) : ( + config.icon && ( + + ) + )} +
+ )} + +
+ + {item.agentName} + + + {item.message} + +
+
+ +
+ {item.priority === "success" ? ( + + + See task + + ) : ( + + )} + +
+
+ ); +} + +function buildAutoPilotPrompt(item: SitrepItemData): string { + switch (item.priority) { + case "error": + return `What happened with ${item.agentName}? It says "${item.message}" — can you check the logs and tell me what to fix?`; + case "running": + return `Give me a status update on the ${item.agentName} run — what has it found so far?`; + case "stale": + return `${item.agentName} hasn't run recently. Should I keep it or update and re-run it?`; + case "success": + return `Show me what ${item.agentName} found in its last run — summarize the results and any key takeaways.`; + case "listening": + return `What is ${item.agentName} listening for? Give me a summary of its trigger configuration.`; + case "scheduled": + return `When is ${item.agentName} scheduled to run next?`; + case "idle": + return `${item.agentName} has been idle. Should I keep it or update and re-run it?`; + } +} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/components/SitrepItem/SitrepList.tsx b/autogpt_platform/frontend/src/app/(platform)/library/components/SitrepItem/SitrepList.tsx new file mode 100644 index 0000000000..5ebf5bfd94 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/library/components/SitrepItem/SitrepList.tsx @@ -0,0 +1,34 @@ +"use client"; + +import { Text } from "@/components/atoms/Text/Text"; +import { ClockCounterClockwise } from "@phosphor-icons/react"; +import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; +import { useSitrepItems } from "./useSitrepItems"; +import { SitrepItem } from "./SitrepItem"; + +interface Props { + agents: LibraryAgent[]; + maxItems?: number; +} + +export function SitrepList({ agents, maxItems = 10 }: Props) { + const items = useSitrepItems(agents, maxItems); + + if (items.length === 0) return null; + + return ( +
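buildAutoPilotPrompt above deliberately has no default branch: the switch subject is the SitrepPriority union and the return type is string, so adding a new priority without a prompt becomes a compile error. The same guarantee is often made explicit with a helper (a common convention, not part of this PR):

function assertNever(value: never): never {
  throw new Error(`Unhandled case: ${String(value)}`);
}
// default: return assertNever(item.priority);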
+
+ + + Recent tasks + +
+
+ {items.map((item) => (
+ <SitrepItem key={item.id} item={item} />
+ ))}
+
+
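useSitrepItems (next file) ranks mixed item types with a fixed severity table before truncating to maxItems; the core ordering, sketched with simplified types:

const SEVERITY: Record<string, number> = {
  error: 0, // failures surface first
  running: 1,
  stale: 2,
  success: 3,
  listening: 4,
  scheduled: 5,
  idle: 6, // least urgent, shown last
};

function rankItems<T extends { priority: string }>(items: T[], max: number): T[] {
  return [...items]
    .sort((a, b) => (SEVERITY[a.priority] ?? 99) - (SEVERITY[b.priority] ?? 99))
    .slice(0, max);
}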
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/components/SitrepItem/useSitrepItems.ts b/autogpt_platform/frontend/src/app/(platform)/library/components/SitrepItem/useSitrepItems.ts new file mode 100644 index 0000000000..2b4a1deb8b --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/library/components/SitrepItem/useSitrepItems.ts @@ -0,0 +1,133 @@ +"use client"; + +import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs"; +import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus"; +import type { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta"; +import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; +import { okData } from "@/app/api/helpers"; +import { useMemo } from "react"; +import type { SitrepItemData, SitrepPriority } from "../../types"; +import { + isActive, + isFailed, + toEndTime, + endedAfter, + runningMessage, + SEVENTY_TWO_HOURS_MS, +} from "../../hooks/executionHelpers"; + +export function useSitrepItems( + agents: LibraryAgent[], + maxItems: number, +): SitrepItemData[] { + const { data: executions } = useGetV1ListAllExecutions({ + query: { select: okData }, + }); + + return useMemo(() => { + if (!executions || agents.length === 0) return []; + + const graphIdToAgent = new Map(agents.map((a) => [a.graph_id, a])); + const agentExecutions = groupByAgent(executions, graphIdToAgent); + const items: SitrepItemData[] = []; + + for (const [agent, execs] of agentExecutions) { + const item = buildSitrepFromExecutions(agent, execs); + if (item) items.push(item); + } + + const order: Record = { + error: 0, + running: 1, + stale: 2, + success: 3, + listening: 4, + scheduled: 5, + idle: 6, + }; + items.sort((a, b) => order[a.priority] - order[b.priority]); + + return items.slice(0, maxItems); + }, [agents, executions, maxItems]); +} + +function groupByAgent( + executions: GraphExecutionMeta[], + graphIdToAgent: Map, +): Map { + const map = new Map(); + + for (const exec of executions) { + const agent = graphIdToAgent.get(exec.graph_id); + if (!agent) continue; + const list = map.get(agent); + if (list) { + list.push(exec); + } else { + map.set(agent, [exec]); + } + } + + return map; +} + +function buildSitrepFromExecutions( + agent: LibraryAgent, + executions: GraphExecutionMeta[], +): SitrepItemData | null { + const active = executions.find((e) => isActive(e.status)); + if (active) { + return { + id: `${agent.id}-${active.id}`, + agentID: agent.id, + agentName: agent.name, + executionID: active.id, + priority: "running", + message: + active.stats?.activity_status ?? + runningMessage(active.status, active.started_at), + status: "running", + }; + } + + const cutoff = Date.now() - SEVENTY_TWO_HOURS_MS; + const recent = executions + .filter((e) => endedAfter(e, cutoff)) + .sort((a, b) => toEndTime(b) - toEndTime(a)); + + const lastFailed = recent.find((e) => isFailed(e.status)); + if (lastFailed) { + const errorMsg = + lastFailed.stats?.error ?? + lastFailed.stats?.activity_status ?? + "Execution failed"; + return { + id: `${agent.id}-${lastFailed.id}`, + agentID: agent.id, + agentName: agent.name, + executionID: lastFailed.id, + priority: "error", + message: typeof errorMsg === "string" ? errorMsg : "Execution failed", + status: "error", + }; + } + + const lastCompleted = recent.find( + (e) => e.status === AgentExecutionStatus.COMPLETED, + ); + if (lastCompleted) { + const summary = + lastCompleted.stats?.activity_status ?? 
"Completed successfully"; + return { + id: `${agent.id}-${lastCompleted.id}`, + agentID: agent.id, + agentName: agent.name, + executionID: lastCompleted.id, + priority: "success", + message: typeof summary === "string" ? summary : "Completed successfully", + status: "idle", + }; + } + + return null; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/components/StatusBadge/StatusBadge.tsx b/autogpt_platform/frontend/src/app/(platform)/library/components/StatusBadge/StatusBadge.tsx new file mode 100644 index 0000000000..afcee51380 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/library/components/StatusBadge/StatusBadge.tsx @@ -0,0 +1,84 @@ +"use client"; + +import { cn } from "@/lib/utils"; +import type { AgentStatus } from "../../types"; + +const STATUS_CONFIG: Record< + AgentStatus, + { label: string; bg: string; text: string; pulse: boolean } +> = { + running: { + label: "Running", + bg: "", + text: "text-blue-600", + pulse: true, + }, + error: { + label: "Error", + bg: "", + text: "text-red-500", + pulse: false, + }, + listening: { + label: "Listening", + bg: "", + text: "text-purple-500", + pulse: true, + }, + scheduled: { + label: "Scheduled", + bg: "", + text: "text-yellow-600", + pulse: false, + }, + idle: { + label: "Idle", + bg: "", + text: "text-zinc-500", + pulse: false, + }, +}; + +interface Props { + status: AgentStatus; + className?: string; +} + +export function StatusBadge({ status, className }: Props) { + const config = STATUS_CONFIG[status]; + + return ( + + + {config.label} + + ); +} + +function statusDotColor(status: AgentStatus): string { + switch (status) { + case "running": + return "bg-blue-500"; + case "error": + return "bg-red-500"; + case "listening": + return "bg-purple-500"; + case "scheduled": + return "bg-yellow-500"; + case "idle": + return "bg-zinc-400"; + } +} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/hooks/executionHelpers.ts b/autogpt_platform/frontend/src/app/(platform)/library/hooks/executionHelpers.ts new file mode 100644 index 0000000000..cd2505c7ce --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/library/hooks/executionHelpers.ts @@ -0,0 +1,59 @@ +import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus"; +import type { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta"; + +export const SEVENTY_TWO_HOURS_MS = 72 * 60 * 60 * 1000; + +export function isActive(status: string): boolean { + return ( + status === AgentExecutionStatus.RUNNING || + status === AgentExecutionStatus.QUEUED || + status === AgentExecutionStatus.REVIEW + ); +} + +export function isFailed(status: string): boolean { + return ( + status === AgentExecutionStatus.FAILED || + status === AgentExecutionStatus.TERMINATED + ); +} + +export function toEndTime(exec: GraphExecutionMeta): number { + if (!exec.ended_at) return 0; + return exec.ended_at instanceof Date + ? exec.ended_at.getTime() + : new Date(exec.ended_at).getTime(); +} + +export function endedAfter(exec: GraphExecutionMeta, cutoff: number): boolean { + if (!exec.ended_at) return false; + return toEndTime(exec) > cutoff; +} + +export function runningMessage( + status: string, + startedAt?: string | Date | null, +): string { + if (status === AgentExecutionStatus.QUEUED) return "Queued for execution"; + if (status === AgentExecutionStatus.REVIEW) return "Awaiting review"; + if (!startedAt) return "Currently executing"; + const ms = + Date.now() - + (startedAt instanceof Date + ? 
startedAt.getTime() + : new Date(startedAt).getTime()); + return `Running for ${formatRelativeDuration(ms)}`; +} + +export function formatRelativeDuration(ms: number): string { + const seconds = Math.floor(ms / 1000); + if (seconds < 60) return "a few seconds"; + const minutes = Math.floor(seconds / 60); + if (minutes < 60) return `${minutes}m`; + const hours = Math.floor(minutes / 60); + const remainingMin = minutes % 60; + if (hours < 24) + return remainingMin > 0 ? `${hours}h ${remainingMin}m` : `${hours}h`; + const days = Math.floor(hours / 24); + return `${days}d ${hours % 24}h`; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/hooks/useAgentStatus.ts b/autogpt_platform/frontend/src/app/(platform)/library/hooks/useAgentStatus.ts new file mode 100644 index 0000000000..ada5560040 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/library/hooks/useAgentStatus.ts @@ -0,0 +1,213 @@ +"use client"; + +import { useMemo } from "react"; +import { useGetV1ListAllExecutions } from "@/app/api/__generated__/endpoints/graphs/graphs"; +import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus"; +import type { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta"; +import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; +import { okData } from "@/app/api/helpers"; +import type { + AgentStatus, + AgentHealth, + AgentStatusInfo, + FleetSummary, +} from "../types"; +import { + isActive, + isFailed, + toEndTime, + SEVENTY_TWO_HOURS_MS, +} from "./executionHelpers"; + +function deriveHealth( + status: AgentStatus, + lastRunAt: string | null, +): AgentHealth { + if (status === "error") return "attention"; + if (status === "idle" && lastRunAt) { + const daysSince = + (Date.now() - new Date(lastRunAt).getTime()) / (1000 * 60 * 60 * 24); + if (daysSince > 14) return "stale"; + } + return "good"; +} + +function computeAgentStatus( + agent: LibraryAgent, + agentExecutions: GraphExecutionMeta[], +): AgentStatusInfo { + const activeExec = agentExecutions.find((e) => isActive(e.status)); + + let status: AgentStatus; + let lastError: string | null = null; + let lastRunAt: string | null = null; + const activeExecutionID = activeExec?.id ?? null; + + if (activeExec) { + status = "running"; + } else { + const cutoff = Date.now() - SEVENTY_TWO_HOURS_MS; + const recentFailed = agentExecutions.find( + (e) => + isFailed(e.status) && + e.ended_at && + new Date( + e.ended_at instanceof Date ? e.ended_at.getTime() : e.ended_at, + ).getTime() > cutoff, + ); + + if (recentFailed) { + status = "error"; + lastError = + (recentFailed.stats?.error as string) ?? + (recentFailed.stats?.activity_status as string) ?? + "Execution failed"; + } else if (agent.has_external_trigger) { + status = "listening"; + } else if (agent.recommended_schedule_cron) { + status = "scheduled"; + } else { + status = "idle"; + } + } + + const completedExecs = agentExecutions.filter((e) => e.ended_at); + if (completedExecs.length > 0) { + const sorted = completedExecs.sort((a, b) => toEndTime(b) - toEndTime(a)); + const endedAt = sorted[0].ended_at; + lastRunAt = + endedAt instanceof Date ? endedAt.toISOString() : String(endedAt); + } + + const totalRuns = agent.execution_count ?? agentExecutions.length; + + return { + status, + health: deriveHealth(status, lastRunAt), + progress: null, + totalRuns, + lastRunAt, + lastError, + activeExecutionID, + monthlySpend: 0, + nextScheduledRun: null, + triggerType: agent.has_external_trigger ? 
"webhook" : null, + }; +} + +export function useAgentStatusMap( + agents: LibraryAgent[], +): Map { + const { data: executions } = useGetV1ListAllExecutions({ + query: { select: okData }, + }); + + return useMemo(() => { + const map = new Map(); + const execsByGraph = new Map(); + + for (const exec of executions ?? []) { + const list = execsByGraph.get(exec.graph_id); + if (list) { + list.push(exec); + } else { + execsByGraph.set(exec.graph_id, [exec]); + } + } + + for (const agent of agents) { + const agentExecs = execsByGraph.get(agent.graph_id) ?? []; + map.set(agent.graph_id, computeAgentStatus(agent, agentExecs)); + } + + return map; + }, [agents, executions]); +} + +const DEFAULT_STATUS: AgentStatusInfo = { + status: "idle", + health: "good", + progress: null, + totalRuns: 0, + lastRunAt: null, + lastError: null, + activeExecutionID: null, + monthlySpend: 0, + nextScheduledRun: null, + triggerType: null, +}; + +export function getAgentStatus( + statusMap: Map, + graphID: string, +): AgentStatusInfo { + return statusMap.get(graphID) ?? DEFAULT_STATUS; +} + +export function useFleetSummary(agents: LibraryAgent[]): FleetSummary { + const { data: executions } = useGetV1ListAllExecutions({ + query: { select: okData }, + }); + + return useMemo(() => { + const counts: FleetSummary = { + running: 0, + error: 0, + completed: 0, + listening: 0, + scheduled: 0, + idle: 0, + monthlySpend: 0, + }; + + const activeGraphIds = new Set(); + const errorGraphIds = new Set(); + const completedGraphIds = new Set(); + + if (executions) { + const cutoff = Date.now() - SEVENTY_TWO_HOURS_MS; + for (const exec of executions) { + if (isActive(exec.status)) { + activeGraphIds.add(exec.graph_id); + } + const endedTs = exec.ended_at + ? new Date( + exec.ended_at instanceof Date + ? 
exec.ended_at.getTime() + : exec.ended_at, + ).getTime() + : 0; + if (isFailed(exec.status) && endedTs > cutoff) { + errorGraphIds.add(exec.graph_id); + } + if ( + exec.status === AgentExecutionStatus.COMPLETED && + endedTs > cutoff + ) { + completedGraphIds.add(exec.graph_id); + } + } + } + + for (const agent of agents) { + if (activeGraphIds.has(agent.graph_id)) { + counts.running += 1; + } else if (errorGraphIds.has(agent.graph_id)) { + counts.error += 1; + } else if (agent.has_external_trigger) { + counts.listening += 1; + } else if (agent.recommended_schedule_cron) { + counts.scheduled += 1; + } else { + counts.idle += 1; + } + if (completedGraphIds.has(agent.graph_id)) { + counts.completed += 1; + } + } + + return counts; + }, [agents, executions]); +} + +export { deriveHealth }; diff --git a/autogpt_platform/frontend/src/app/(platform)/library/hooks/useLibraryFleetSummary.ts b/autogpt_platform/frontend/src/app/(platform)/library/hooks/useLibraryFleetSummary.ts new file mode 100644 index 0000000000..8aa7a92812 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/library/hooks/useLibraryFleetSummary.ts @@ -0,0 +1,116 @@ +"use client"; + +import { + getGetV1ListAllExecutionsQueryKey, + useGetV1ListAllExecutions, +} from "@/app/api/__generated__/endpoints/graphs/graphs"; +import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus"; +import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; +import { okData } from "@/app/api/helpers"; +import { useExecutionEvents } from "@/hooks/useExecutionEvents"; +import { useQueryClient } from "@tanstack/react-query"; +import { useCallback, useMemo } from "react"; +import type { FleetSummary } from "../types"; +import { isActive, isFailed, SEVENTY_TWO_HOURS_MS } from "./executionHelpers"; + +function isRecentFailure( + status: string, + endedAt?: string | Date | null, +): boolean { + if (!isFailed(status)) return false; + if (!endedAt) return false; + const ts = + endedAt instanceof Date ? endedAt.getTime() : new Date(endedAt).getTime(); + return Date.now() - ts < SEVENTY_TWO_HOURS_MS; +} + +function isRecentCompletion( + status: string, + endedAt?: string | Date | null, +): boolean { + if (status !== AgentExecutionStatus.COMPLETED) return false; + if (!endedAt) return false; + const ts = + endedAt instanceof Date ? endedAt.getTime() : new Date(endedAt).getTime(); + return Date.now() - ts < SEVENTY_TWO_HOURS_MS; +} + +export function useLibraryFleetSummary( + agents: LibraryAgent[], +): FleetSummary | undefined { + const queryClient = useQueryClient(); + + const { data: executions, isSuccess } = useGetV1ListAllExecutions({ + query: { select: okData }, + }); + + const graphIDs = useMemo(() => agents.map((a) => a.graph_id), [agents]); + + const handleExecutionUpdate = useCallback(() => { + queryClient.invalidateQueries({ + queryKey: getGetV1ListAllExecutionsQueryKey(), + }); + }, [queryClient]); + + useExecutionEvents({ + graphIds: graphIDs.length > 0 ? 
graphIDs : undefined, + enabled: graphIDs.length > 0, + onExecutionUpdate: handleExecutionUpdate, + }); + + return useMemo(() => { + if (!isSuccess || !executions) return undefined; + + const agentsWithActiveExecution = new Set(); + const agentsWithRecentFailure = new Set(); + const agentsWithRecentCompletion = new Set(); + + for (const exec of executions) { + if (isActive(exec.status)) { + agentsWithActiveExecution.add(exec.graph_id); + } + if (isRecentFailure(exec.status, exec.ended_at)) { + agentsWithRecentFailure.add(exec.graph_id); + } + if (isRecentCompletion(exec.status, exec.ended_at)) { + agentsWithRecentCompletion.add(exec.graph_id); + } + } + + const summary: FleetSummary = { + running: 0, + error: 0, + completed: 0, + listening: 0, + scheduled: 0, + idle: 0, + monthlySpend: 0, + }; + + for (const agent of agents) { + if (agentsWithActiveExecution.has(agent.graph_id)) { + summary.running += 1; + } else if (agentsWithRecentFailure.has(agent.graph_id)) { + summary.error += 1; + } else if (agent.has_external_trigger) { + summary.listening += 1; + } else if (agent.recommended_schedule_cron) { + summary.scheduled += 1; + } else { + summary.idle += 1; + } + // Parallel counter: mutually exclusive with running/error (which match + // the sitrep priority order used by the "Recently completed" tab list) + // but orthogonal to listening/scheduled/idle. + if ( + !agentsWithActiveExecution.has(agent.graph_id) && + !agentsWithRecentFailure.has(agent.graph_id) && + agentsWithRecentCompletion.has(agent.graph_id) + ) { + summary.completed += 1; + } + } + + return summary; + }, [agents, executions, isSuccess]); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/library/page.tsx b/autogpt_platform/frontend/src/app/(platform)/library/page.tsx index f88c4a64dd..b660999520 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/page.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/page.tsx @@ -2,12 +2,14 @@ import { useEffect, useState, useCallback } from "react"; import { HeartIcon, ListIcon } from "@phosphor-icons/react"; -import { JumpBackIn } from "./components/JumpBackIn/JumpBackIn"; import { LibraryActionHeader } from "./components/LibraryActionHeader/LibraryActionHeader"; import { LibraryAgentList } from "./components/LibraryAgentList/LibraryAgentList"; import { useLibraryListPage } from "./components/useLibraryListPage"; import { FavoriteAnimationProvider } from "./context/FavoriteAnimationContext"; -import { LibraryTab } from "./types"; +import type { LibraryTab, AgentStatusFilter } from "./types"; +import { useLibraryFleetSummary } from "./hooks/useLibraryFleetSummary"; +import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag"; +import { useLibraryAgents } from "@/hooks/useLibraryAgents/useLibraryAgents"; const LIBRARY_TABS: LibraryTab[] = [ { id: "all", title: "All", icon: ListIcon }, @@ -19,6 +21,10 @@ export default function LibraryPage() { useLibraryListPage(); const [selectedFolderId, setSelectedFolderId] = useState(null); const [activeTab, setActiveTab] = useState(LIBRARY_TABS[0].id); + const [statusFilter, setStatusFilter] = useState("all"); + const isAgentBriefingEnabled = useGetFlag(Flag.AGENT_BRIEFING); + const { agents } = useLibraryAgents(); + const fleetSummary = useLibraryFleetSummary(agents); useEffect(() => { document.title = "Library – AutoGPT Platform"; @@ -40,7 +46,6 @@ export default function LibraryPage() { >
-
diff --git a/autogpt_platform/frontend/src/app/(platform)/library/types.ts b/autogpt_platform/frontend/src/app/(platform)/library/types.ts index dad4096fc4..b5253b41bc 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/types.ts +++ b/autogpt_platform/frontend/src/app/(platform)/library/types.ts @@ -1,7 +1,76 @@ -import { Icon } from "@phosphor-icons/react"; +import type { Icon } from "@phosphor-icons/react"; export interface LibraryTab { id: string; title: string; icon: Icon; } + +/** Agent execution status — drives StatusBadge visuals & filtering. */ +export type AgentStatus = + | "running" + | "error" + | "listening" + | "scheduled" + | "idle"; + +/** Derived health bucket for quick triage. */ +export type AgentHealth = "good" | "attention" | "stale"; + +/** Real-time metadata that powers the Intelligence Layer features. */ +export interface AgentStatusInfo { + status: AgentStatus; + health: AgentHealth; + /** 0-100 progress for currently running agents. */ + progress: number | null; + totalRuns: number; + lastRunAt: string | null; + lastError: string | null; + /** ID of the currently active execution (when status is "running"). */ + activeExecutionID: string | null; + monthlySpend: number; + nextScheduledRun: string | null; + triggerType: string | null; +} + +/** Fleet-wide aggregate counts used by the Briefing Panel stats grid. */ +export interface FleetSummary { + running: number; + error: number; + completed: number; + listening: number; + scheduled: number; + idle: number; + monthlySpend: number; +} + +export type SitrepPriority = + | "error" + | "running" + | "stale" + | "success" + | "listening" + | "scheduled" + | "idle"; + +export interface SitrepItemData { + id: string; + agentID: string; + agentName: string; + agentImageUrl?: string | null; + executionID?: string; + priority: SitrepPriority; + message: string; + status: AgentStatus; +} + +/** Filter options for the agent filter dropdown. 
*/ +export type AgentStatusFilter = + | "all" + | "running" + | "attention" + | "completed" + | "listening" + | "scheduled" + | "idle" + | "healthy"; diff --git a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/SubscriptionTierSection.tsx b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/SubscriptionTierSection.tsx index 774fe01ed9..58a4b9d58b 100644 --- a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/SubscriptionTierSection.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/SubscriptionTierSection.tsx @@ -1,6 +1,8 @@ "use client"; import { useState } from "react"; import { Button } from "@/components/ui/button"; +import { Dialog } from "@/components/molecules/Dialog/Dialog"; +import { Skeleton } from "@/components/atoms/Skeleton/Skeleton"; import { useSubscriptionTierSection } from "./useSubscriptionTierSection"; type TierInfo = { @@ -15,39 +17,70 @@ const TIERS: TierInfo[] = [ key: "FREE", label: "Free", multiplier: "1x", - description: "Base rate limits", + description: "Base AutoPilot capacity with standard rate limits", }, { key: "PRO", label: "Pro", multiplier: "5x", - description: "5x more AutoPilot capacity", + description: "5x AutoPilot capacity — run 5× more tasks per day/week", }, { key: "BUSINESS", label: "Business", multiplier: "20x", - description: "20x more AutoPilot capacity", + description: "20x AutoPilot capacity — ideal for teams and heavy workloads", }, ]; -function formatCost(cents: number): string { - if (cents === 0) return "Free"; +const TIER_ORDER = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"]; + +function formatCost(cents: number, tierKey: string): string { + if (tierKey === "FREE") return "Free"; + if (cents === 0) return "Pricing available soon"; return `$${(cents / 100).toFixed(2)}/mo`; } export function SubscriptionTierSection() { - const { subscription, isLoading, error, isPending, changeTier } = - useSubscriptionTierSection(); - const [tierError, setTierError] = useState(null); + const { + subscription, + isLoading, + error, + tierError, + isPending, + pendingTier, + pendingUpgradeTier, + setPendingUpgradeTier, + confirmUpgrade, + isPaymentEnabled, + changeTier, + handleTierChange, + } = useSubscriptionTierSection(); + const [confirmDowngradeTo, setConfirmDowngradeTo] = useState( + null, + ); - if (isLoading) return null; + if (isLoading) { + return ( +
+ +
+ + + +
+
+ ); + } if (error) { return (

Subscription Plan

-

+

{error}

@@ -56,10 +89,30 @@ export function SubscriptionTierSection() { if (!subscription) return null; - async function handleTierChange(tierKey: string) { - setTierError(null); - const err = await changeTier(tierKey); - if (err) setTierError(err); + const currentTier = subscription.tier; + + if (currentTier === "ENTERPRISE") { + return ( +
+
+ Subscription Plan
+
+
+ Enterprise Plan
+
+ Your Enterprise plan is managed by your administrator. Contact your
+ account team for changes.
+
+
+ ); + } + + async function confirmDowngrade() { + if (!confirmDowngradeTo) return; + const tier = confirmDowngradeTo; + setConfirmDowngradeTo(null); + await changeTier(tier); } return ( @@ -67,24 +120,28 @@ export function SubscriptionTierSection() {
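This hunk classifies each tier card by its index position in TIER_ORDER; the decision rule in isolation (sketch):

const TIER_ORDER = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"];

function tierDirection(
  current: string,
  target: string,
): "upgrade" | "downgrade" | "current" {
  const delta = TIER_ORDER.indexOf(target) - TIER_ORDER.indexOf(current);
  return delta > 0 ? "upgrade" : delta < 0 ? "downgrade" : "current";
}

// tierDirection("PRO", "BUSINESS") -> "upgrade"   (immediate, via Stripe)
// tierDirection("PRO", "FREE")     -> "downgrade" (scheduled for period end)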

Subscription Plan

{tierError && ( -

+

{tierError}

)}
{TIERS.map((tier) => { - const isCurrent = subscription.tier === tier.key; + const isCurrent = currentTier === tier.key; const cost = subscription.tier_costs[tier.key] ?? 0; - const currentTierOrder = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"]; - const currentIdx = currentTierOrder.indexOf(subscription.tier); - const targetIdx = currentTierOrder.indexOf(tier.key); + const currentIdx = TIER_ORDER.indexOf(currentTier); + const targetIdx = TIER_ORDER.indexOf(tier.key); const isUpgrade = targetIdx > currentIdx; const isDowngrade = targetIdx < currentIdx; + const isThisPending = pendingTier === tier.key; return (
-

{formatCost(cost)}

+

+ {formatCost(cost, tier.key)} +

{tier.multiplier} rate limits

@@ -108,14 +167,20 @@ export function SubscriptionTierSection() { {tier.description}

- {!isCurrent && ( + {!isCurrent && isPaymentEnabled && (
- {subscription.tier !== "FREE" && ( + {currentTier !== "FREE" && isPaymentEnabled && (

- Your subscription is managed through Stripe. Changes take effect - immediately. + Your subscription is managed through Stripe. Upgrades and paid-tier + changes take effect immediately; downgrades to Free are scheduled for + the end of the current billing period.

)} + + { + if (!open) setConfirmDowngradeTo(null); + }, + }} + > + +
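Both confirmation dialogs drive the molecule Dialog through its controlled API. From the usage here and the test mock further down, the relevant prop shape is approximately (a sketch, not the component's full contract):

interface ControlledDialogProps {
  controlled?: {
    isOpen: boolean;
    set: (open: boolean) => void; // invoked with false when the dialog is dismissed
  };
}

// e.g. for the downgrade dialog:
// controlled={{
//   isOpen: confirmDowngradeTo !== null,
//   set: (open) => { if (!open) setConfirmDowngradeTo(null); },
// }}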

+ {confirmDowngradeTo === "FREE" + ? "Downgrading to Free will schedule your subscription to cancel at the end of your current billing period. You keep your current plan until then." + : `Switching to ${TIERS.find((t) => t.key === confirmDowngradeTo)?.label ?? confirmDowngradeTo} will take effect immediately.`}{" "} + Are you sure? +

+ + + + +
+
+ + { + if (!open) setPendingUpgradeTier(null); + }, + }} + > + +

+ {subscription && + subscription.proration_credit_cents > 0 && + `Your unused ${currentTier.charAt(0) + currentTier.slice(1).toLowerCase()} subscription ($${(subscription.proration_credit_cents / 100).toFixed(2)}) will be applied as a credit to your next Stripe invoice. `} + You will be redirected to Stripe to complete your upgrade to{" "} + {TIERS.find((t) => t.key === pendingUpgradeTier)?.label ?? + pendingUpgradeTier} + . +

+ + + + +
+
); } diff --git a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/__tests__/SubscriptionTierSection.test.tsx b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/__tests__/SubscriptionTierSection.test.tsx new file mode 100644 index 0000000000..086c383337 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/__tests__/SubscriptionTierSection.test.tsx @@ -0,0 +1,358 @@ +import { + render, + screen, + fireEvent, + waitFor, + cleanup, +} from "@/tests/integrations/test-utils"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import { SubscriptionTierSection } from "../SubscriptionTierSection"; + +// Mock next/navigation +const mockSearchParams = new URLSearchParams(); +const mockRouterReplace = vi.fn(); +vi.mock("next/navigation", async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + useSearchParams: () => mockSearchParams, + useRouter: () => ({ push: vi.fn(), replace: mockRouterReplace }), + usePathname: () => "/profile/credits", + }; +}); + +// Mock toast +const mockToast = vi.fn(); +vi.mock("@/components/molecules/Toast/use-toast", () => ({ + useToast: () => ({ toast: mockToast }), +})); + +// Mock feature flags — default to payment enabled so button tests work +let mockPaymentEnabled = true; +vi.mock("@/services/feature-flags/use-get-flag", () => ({ + Flag: { ENABLE_PLATFORM_PAYMENT: "enable-platform-payment" }, + useGetFlag: () => mockPaymentEnabled, +})); + +// Mock generated API hooks +const mockUseGetSubscriptionStatus = vi.fn(); +const mockUseUpdateSubscriptionTier = vi.fn(); +vi.mock("@/app/api/__generated__/endpoints/credits/credits", () => ({ + useGetSubscriptionStatus: (opts: unknown) => + mockUseGetSubscriptionStatus(opts), + useUpdateSubscriptionTier: () => mockUseUpdateSubscriptionTier(), +})); + +// Mock Dialog (Radix portals don't work in happy-dom) +const MockDialogContent = ({ children }: { children: React.ReactNode }) => ( +
{children}
+); +const MockDialogFooter = ({ children }: { children: React.ReactNode }) => ( +
{children}
+); +function MockDialog({ + controlled, + children, +}: { + controlled?: { isOpen: boolean; set: (open: boolean) => void }; + children: React.ReactNode; + [key: string]: unknown; +}) { + return controlled?.isOpen ?
{children}
: null; +} +MockDialog.Content = MockDialogContent; +MockDialog.Footer = MockDialogFooter; +vi.mock("@/components/molecules/Dialog/Dialog", () => ({ + Dialog: MockDialog, +})); + +function makeSubscription({ + tier = "FREE", + monthlyCost = 0, + tierCosts = { FREE: 0, PRO: 1999, BUSINESS: 4999, ENTERPRISE: 0 }, + prorationCreditCents = 0, +}: { + tier?: string; + monthlyCost?: number; + tierCosts?: Record; + prorationCreditCents?: number; +} = {}) { + return { + tier, + monthly_cost: monthlyCost, + tier_costs: tierCosts, + proration_credit_cents: prorationCreditCents, + }; +} + +function setupMocks({ + subscription = makeSubscription(), + isLoading = false, + queryError = null as Error | null, + mutateFn = vi.fn().mockResolvedValue({ status: 200, data: { url: "" } }), + isPending = false, + variables = undefined as { data?: { tier?: string } } | undefined, +} = {}) { + // The hook uses select: (data) => (data.status === 200 ? data.data : null) + // so the data value returned by the hook is already the transformed subscription object. + // We simulate that by returning the subscription directly as data. + mockUseGetSubscriptionStatus.mockReturnValue({ + data: subscription, + isLoading, + error: queryError, + refetch: vi.fn(), + }); + mockUseUpdateSubscriptionTier.mockReturnValue({ + mutateAsync: mutateFn, + isPending, + variables, + }); +} + +afterEach(() => { + cleanup(); + mockUseGetSubscriptionStatus.mockReset(); + mockUseUpdateSubscriptionTier.mockReset(); + mockToast.mockReset(); + mockRouterReplace.mockReset(); + mockSearchParams.delete("subscription"); + mockPaymentEnabled = true; +}); + +describe("SubscriptionTierSection", () => { + it("renders skeleton cards while loading", () => { + setupMocks({ isLoading: true }); + render(); + // Just verify we're rendering something (not null) and no tier cards + expect(screen.queryByText("Pro")).toBeNull(); + expect(screen.queryByText("Business")).toBeNull(); + }); + + it("renders error message when subscription fetch fails", () => { + setupMocks({ + queryError: new Error("Network error"), + subscription: makeSubscription(), + }); + // Override the data to simulate failed state + mockUseGetSubscriptionStatus.mockReturnValue({ + data: null, + isLoading: false, + error: new Error("Network error"), + refetch: vi.fn(), + }); + render(); + expect(screen.getByRole("alert")).toBeDefined(); + expect(screen.getByText(/failed to load subscription info/i)).toBeDefined(); + }); + + it("renders all three tier cards for FREE user", () => { + setupMocks(); + render(); + // Use getAllByText to account for the tier label AND cost display both containing "Free" + expect(screen.getAllByText("Free").length).toBeGreaterThan(0); + expect(screen.getByText("Pro")).toBeDefined(); + expect(screen.getByText("Business")).toBeDefined(); + }); + + it("shows Current badge on the active tier", () => { + setupMocks({ subscription: makeSubscription({ tier: "PRO" }) }); + render(); + expect(screen.getByText("Current")).toBeDefined(); + // Upgrade to PRO button should NOT exist; Upgrade to BUSINESS and Downgrade to Free should + expect( + screen.queryByRole("button", { name: /upgrade to pro/i }), + ).toBeNull(); + expect( + screen.getByRole("button", { name: /upgrade to business/i }), + ).toBeDefined(); + expect( + screen.getByRole("button", { name: /downgrade to free/i }), + ).toBeDefined(); + }); + + it("displays tier costs from the API", () => { + setupMocks({ + subscription: makeSubscription({ + tier: "FREE", + tierCosts: { FREE: 0, PRO: 1999, BUSINESS: 4999, ENTERPRISE: 0 
}, + }), + }); + render(); + expect(screen.getByText("$19.99/mo")).toBeDefined(); + expect(screen.getByText("$49.99/mo")).toBeDefined(); + // FREE tier label should still be visible (there may be multiple "Free" elements) + expect(screen.getAllByText("Free").length).toBeGreaterThan(0); + }); + + it("shows 'Pricing available soon' when tier cost is 0 for a paid tier", () => { + setupMocks({ + subscription: makeSubscription({ + tier: "FREE", + tierCosts: { FREE: 0, PRO: 0, BUSINESS: 0, ENTERPRISE: 0 }, + }), + }); + render(); + // PRO and BUSINESS with cost=0 should show "Pricing available soon" + expect(screen.getAllByText("Pricing available soon")).toHaveLength(2); + }); + + it("calls changeTier on upgrade click after confirmation dialog", async () => { + const mutateFn = vi + .fn() + .mockResolvedValue({ status: 200, data: { url: "" } }); + setupMocks({ mutateFn }); + render(); + + // Clicking upgrade opens the confirmation dialog first + fireEvent.click(screen.getByRole("button", { name: /upgrade to pro/i })); + // Confirm via the dialog's "Continue to Checkout" button + fireEvent.click( + screen.getByRole("button", { name: /continue to checkout/i }), + ); + + await waitFor(() => { + expect(mutateFn).toHaveBeenCalledWith( + expect.objectContaining({ + data: expect.objectContaining({ tier: "PRO" }), + }), + ); + }); + }); + + it("shows confirmation dialog on downgrade click", () => { + setupMocks({ subscription: makeSubscription({ tier: "PRO" }) }); + render(); + + fireEvent.click(screen.getByRole("button", { name: /downgrade to free/i })); + + expect(screen.getByRole("dialog")).toBeDefined(); + // The dialog title text appears in both a div and a button — just check the dialog is open + expect(screen.getAllByText(/confirm downgrade/i).length).toBeGreaterThan(0); + }); + + it("calls changeTier after downgrade confirmation", async () => { + const mutateFn = vi + .fn() + .mockResolvedValue({ status: 200, data: { url: "" } }); + setupMocks({ + subscription: makeSubscription({ tier: "PRO" }), + mutateFn, + }); + render(); + + fireEvent.click(screen.getByRole("button", { name: /downgrade to free/i })); + fireEvent.click(screen.getByRole("button", { name: /confirm downgrade/i })); + + await waitFor(() => { + expect(mutateFn).toHaveBeenCalledWith( + expect.objectContaining({ + data: expect.objectContaining({ tier: "FREE" }), + }), + ); + }); + }); + + it("dismisses dialog when Cancel is clicked", () => { + setupMocks({ subscription: makeSubscription({ tier: "PRO" }) }); + render(); + + fireEvent.click(screen.getByRole("button", { name: /downgrade to free/i })); + expect(screen.getByRole("dialog")).toBeDefined(); + + fireEvent.click(screen.getByRole("button", { name: /^cancel$/i })); + expect(screen.queryByRole("dialog")).toBeNull(); + }); + + it("redirects to Stripe when checkout URL is returned", async () => { + // Replace window.location with a plain object so assigning .href doesn't + // trigger jsdom navigation (which would throw or reload the test page). 
+ const mockLocation = { href: "" }; + vi.stubGlobal("location", mockLocation); + + const mutateFn = vi.fn().mockResolvedValue({ + status: 200, + data: { url: "https://checkout.stripe.com/pay/cs_test" }, + }); + setupMocks({ mutateFn }); + render(); + + // Upgrade opens confirmation dialog first — confirm via "Continue to Checkout" + fireEvent.click(screen.getByRole("button", { name: /upgrade to pro/i })); + fireEvent.click( + screen.getByRole("button", { name: /continue to checkout/i }), + ); + + await waitFor(() => { + expect(mockLocation.href).toBe("https://checkout.stripe.com/pay/cs_test"); + }); + + vi.unstubAllGlobals(); + }); + + it("shows an error alert when tier change fails", async () => { + const mutateFn = vi.fn().mockRejectedValue(new Error("Stripe unavailable")); + setupMocks({ mutateFn }); + render(); + + // Upgrade opens confirmation dialog first — confirm to trigger the mutation + fireEvent.click(screen.getByRole("button", { name: /upgrade to pro/i })); + fireEvent.click( + screen.getByRole("button", { name: /continue to checkout/i }), + ); + + await waitFor(() => { + expect(screen.getByRole("alert")).toBeDefined(); + expect(screen.getByText(/stripe unavailable/i)).toBeDefined(); + }); + }); + + it("hides action buttons when payment flag is disabled", () => { + mockPaymentEnabled = false; + setupMocks({ subscription: makeSubscription({ tier: "FREE" }) }); + render(); + // Tier cards still visible + expect(screen.getByText("Pro")).toBeDefined(); + expect(screen.getByText("Business")).toBeDefined(); + // No upgrade/downgrade buttons + expect(screen.queryByRole("button", { name: /upgrade/i })).toBeNull(); + expect(screen.queryByRole("button", { name: /downgrade/i })).toBeNull(); + }); + + it("shows ENTERPRISE message for ENTERPRISE tier users", () => { + setupMocks({ subscription: makeSubscription({ tier: "ENTERPRISE" }) }); + render(); + // Enterprise heading text appears in a

(may match multiple), just verify it exists + expect(screen.getAllByText(/enterprise plan/i).length).toBeGreaterThan(0); + expect(screen.getByText(/managed by your administrator/i)).toBeDefined(); + // No standard tier cards should be rendered + expect(screen.queryByText("Pro")).toBeNull(); + expect(screen.queryByText("Business")).toBeNull(); + }); + + it("shows success toast and clears URL param when ?subscription=success is present", async () => { + mockSearchParams.set("subscription", "success"); + setupMocks(); + render(); + + await waitFor(() => { + expect(mockToast).toHaveBeenCalledWith( + expect.objectContaining({ title: "Subscription upgraded" }), + ); + }); + // URL param must be stripped so a page refresh doesn't re-trigger the toast + expect(mockRouterReplace).toHaveBeenCalledWith("/profile/credits"); + }); + + it("clears URL param but shows no toast when ?subscription=cancelled is present", async () => { + mockSearchParams.set("subscription", "cancelled"); + setupMocks(); + render(); + + // The cancelled param must be stripped from the URL (same hygiene as success) + await waitFor(() => { + expect(mockRouterReplace).toHaveBeenCalledWith("/profile/credits"); + }); + // No toast should fire — the user simply abandoned checkout + expect(mockToast).not.toHaveBeenCalled(); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/useSubscriptionTierSection.ts b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/useSubscriptionTierSection.ts index b0fe635b72..862551c7e3 100644 --- a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/useSubscriptionTierSection.ts +++ b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/useSubscriptionTierSection.ts @@ -1,13 +1,30 @@ +import { useEffect, useState } from "react"; +import { usePathname, useRouter, useSearchParams } from "next/navigation"; import { useGetSubscriptionStatus, useUpdateSubscriptionTier, } from "@/app/api/__generated__/endpoints/credits/credits"; import type { SubscriptionStatusResponse } from "@/app/api/__generated__/models/subscriptionStatusResponse"; import type { SubscriptionTierRequestTier } from "@/app/api/__generated__/models/subscriptionTierRequestTier"; +import { useToast } from "@/components/molecules/Toast/use-toast"; +import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag"; export type SubscriptionStatus = SubscriptionStatusResponse; +const TIER_ORDER = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"]; + export function useSubscriptionTierSection() { + const isPaymentEnabled = useGetFlag(Flag.ENABLE_PLATFORM_PAYMENT); + const searchParams = useSearchParams(); + const subscriptionStatus = searchParams.get("subscription"); + const router = useRouter(); + const pathname = usePathname(); + const { toast } = useToast(); + const [tierError, setTierError] = useState(null); + const [pendingUpgradeTier, setPendingUpgradeTier] = useState( + null, + ); + const { data: subscription, isLoading, @@ -17,11 +34,39 @@ export function useSubscriptionTierSection() { query: { select: (data) => (data.status === 200 ? data.data : null) }, }); - const error = queryError ? "Failed to load subscription info" : null; + const fetchError = queryError ? 
"Failed to load subscription info" : null; - const { mutateAsync: doUpdateTier, isPending } = useUpdateSubscriptionTier(); + const { + mutateAsync: doUpdateTier, + isPending, + variables, + } = useUpdateSubscriptionTier(); - async function changeTier(tier: string): Promise { + useEffect(() => { + if (subscriptionStatus === "success") { + refetch(); + toast({ + title: "Subscription upgraded", + description: + "Your plan has been updated. It may take a moment to reflect.", + }); + } + // Strip ?subscription=success|cancelled from the URL so a page refresh + // does not re-trigger side-effects, and so a second checkout in the same + // session correctly fires the toast again. + if ( + subscriptionStatus === "success" || + subscriptionStatus === "cancelled" + ) { + router.replace(pathname); + } + // eslint-disable-next-line react-hooks/exhaustive-deps -- refetch and toast + // are new references each render but are stable in practice; the effect must + // only re-run when subscriptionStatus/pathname changes. + }, [subscriptionStatus, refetch, toast, router, pathname]); + + async function changeTier(tier: string) { + setTierError(null); try { const successUrl = `${window.location.origin}${window.location.pathname}?subscription=success`; const cancelUrl = `${window.location.origin}${window.location.pathname}?subscription=cancelled`; @@ -34,22 +79,59 @@ export function useSubscriptionTierSection() { }); if (result.status === 200 && result.data.url) { window.location.href = result.data.url; - return null; + return; } await refetch(); - return null; + toast({ + title: "Subscription updated", + description: + tier === "FREE" + ? "Your plan will be downgraded to Free at the end of your current billing period." + : "Your subscription has been updated.", + }); } catch (e: unknown) { const msg = e instanceof Error ? e.message : "Failed to change subscription tier"; - return msg; + setTierError(msg); } } + function handleTierChange( + targetTierKey: string, + currentTier: string, + onConfirmDowngrade: (tier: string) => void, + ) { + const currentIdx = TIER_ORDER.indexOf(currentTier); + const targetIdx = TIER_ORDER.indexOf(targetTierKey); + if (targetIdx < currentIdx) { + onConfirmDowngrade(targetTierKey); + return; + } + setPendingUpgradeTier(targetTierKey); + } + + async function confirmUpgrade() { + if (!pendingUpgradeTier) return; + const tier = pendingUpgradeTier; + setPendingUpgradeTier(null); + await changeTier(tier); + } + + const pendingTier = + isPending && variables?.data?.tier ? variables.data.tier : null; + return { subscription: subscription ?? 
null, isLoading, - error, + error: fetchError, + tierError, isPending, + pendingTier, + pendingUpgradeTier, + setPendingUpgradeTier, + confirmUpgrade, + isPaymentEnabled, changeTier, + handleTierChange, }; } diff --git a/autogpt_platform/frontend/src/app/api/auth/user/route.ts b/autogpt_platform/frontend/src/app/api/auth/user/route.ts index 896385d865..63cef27fc5 100644 --- a/autogpt_platform/frontend/src/app/api/auth/user/route.ts +++ b/autogpt_platform/frontend/src/app/api/auth/user/route.ts @@ -15,15 +15,35 @@ export async function GET() { export async function PUT(request: Request) { try { const supabase = await getServerSupabase(); - const { email } = await request.json(); - if (!email) { - return NextResponse.json({ error: "Email is required" }, { status: 400 }); + let body: unknown; + try { + body = await request.json(); + } catch { + return NextResponse.json({ error: "Invalid JSON body" }, { status: 400 }); } - const { data, error } = await supabase.auth.updateUser({ - email, - }); + const { email: rawEmail, full_name: rawFullName } = body as { + email?: unknown; + full_name?: unknown; + }; + + const email = typeof rawEmail === "string" ? rawEmail.trim() : undefined; + const fullName = + typeof rawFullName === "string" ? rawFullName.trim() : undefined; + + if (!email && !fullName) { + return NextResponse.json( + { error: "Email or full_name is required" }, + { status: 400 }, + ); + } + + const updatePayload: Parameters[0] = {}; + if (email) updatePayload.email = email; + if (fullName) updatePayload.data = { full_name: fullName }; + + const { data, error } = await supabase.auth.updateUser(updatePayload); if (error) { return NextResponse.json({ error: error.message }, { status: 400 }); @@ -32,7 +52,7 @@ export async function PUT(request: Request) { return NextResponse.json(data); } catch { return NextResponse.json( - { error: "Failed to update user email" }, + { error: "Failed to update user" }, { status: 500 }, ); } diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index 6432ab79cf..82ac825ced 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -1501,7 +1501,7 @@ "get": { "tags": ["v2", "chat", "chat"], "summary": "Get Session", - "description": "Retrieve the details of a specific chat session.\n\nSupports cursor-based pagination via ``limit`` and ``before_sequence``.\nWhen no pagination params are provided, returns the most recent messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The authenticated user's ID.\n limit: Maximum number of messages to return (1-200, default 50).\n before_sequence: Return messages with sequence < this value (cursor).\n\nReturns:\n SessionDetailResponse: Details for the requested session, including\n active_stream info and pagination metadata.", + "description": "Retrieve the details of a specific chat session.\n\nSupports cursor-based pagination via ``limit``, ``before_sequence``, and\n``after_sequence``. The two cursor parameters are mutually exclusive.\n\nOn the initial load (no cursor provided) of a completed session, messages\nare returned in forward order starting from sequence 0 so the user always\nsees their initial prompt. 
Active sessions use the legacy newest-first\norder so streaming context is preserved.", "operationId": "getV2GetSession", "security": [{ "HTTPBearerJWT": [] }], "parameters": [ @@ -1519,9 +1519,11 @@ "type": "integer", "maximum": 200, "minimum": 1, + "description": "Maximum number of messages to return.", "default": 50, "title": "Limit" - } + }, + "description": "Maximum number of messages to return." }, { "name": "before_sequence", @@ -1532,8 +1534,24 @@ { "type": "integer", "minimum": 0 }, { "type": "null" } ], + "description": "Backward pagination cursor. Return messages with sequence number strictly less than this value. Used by active-session load-more. Mutually exclusive with after_sequence.", "title": "Before Sequence" - } + }, + "description": "Backward pagination cursor. Return messages with sequence number strictly less than this value. Used by active-session load-more. Mutually exclusive with after_sequence." + }, + { + "name": "after_sequence", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { "type": "integer", "minimum": 0 }, + { "type": "null" } + ], + "description": "Forward pagination cursor. Return messages with sequence number strictly greater than this value. Used by completed-session load-more. Mutually exclusive with before_sequence.", + "title": "After Sequence" + }, + "description": "Forward pagination cursor. Return messages with sequence number strictly greater than this value. Used by completed-session load-more. Mutually exclusive with before_sequence." } ], "responses": { @@ -13351,6 +13369,15 @@ "anyOf": [{ "type": "integer" }, { "type": "null" }], "title": "Oldest Sequence" }, + "newest_sequence": { + "anyOf": [{ "type": "integer" }, { "type": "null" }], + "title": "Newest Sequence" + }, + "forward_paginated": { + "type": "boolean", + "title": "Forward Paginated", + "default": false + }, "total_prompt_tokens": { "type": "integer", "title": "Total Prompt Tokens", @@ -14200,16 +14227,29 @@ }, "SubscriptionStatusResponse": { "properties": { - "tier": { "type": "string", "title": "Tier" }, + "tier": { + "type": "string", + "enum": ["FREE", "PRO", "BUSINESS", "ENTERPRISE"], + "title": "Tier" + }, "monthly_cost": { "type": "integer", "title": "Monthly Cost" }, "tier_costs": { "additionalProperties": { "type": "integer" }, "type": "object", "title": "Tier Costs" + }, + "proration_credit_cents": { + "type": "integer", + "title": "Proration Credit Cents" } }, "type": "object", - "required": ["tier", "monthly_cost", "tier_costs"], + "required": [ + "tier", + "monthly_cost", + "tier_costs", + "proration_credit_cents" + ], "title": "SubscriptionStatusResponse" }, "SubscriptionTier": { diff --git a/autogpt_platform/frontend/src/app/layout.tsx b/autogpt_platform/frontend/src/app/layout.tsx index f793d7dc2b..df67b9d0c2 100644 --- a/autogpt_platform/frontend/src/app/layout.tsx +++ b/autogpt_platform/frontend/src/app/layout.tsx @@ -12,6 +12,7 @@ import { Toaster } from "@/components/molecules/Toast/toaster"; import { SetupAnalytics } from "@/services/analytics"; import { VercelAnalyticsWrapper } from "@/services/analytics/VercelAnalyticsWrapper"; import { environment } from "@/services/environment"; +import AgentationDevtool from "@/components/AgentationDevtool"; import { ReactQueryDevtools } from "@tanstack/react-query-devtools"; import { headers } from "next/headers"; @@ -77,6 +78,7 @@ export default async function RootLayout({

+          {(isLocal || isDev) && <AgentationDevtool />}
diff --git a/autogpt_platform/frontend/src/components/AgentationDevtool.tsx b/autogpt_platform/frontend/src/components/AgentationDevtool.tsx
new file mode 100644
index 0000000000..82b59c78e8
--- /dev/null
+++ b/autogpt_platform/frontend/src/components/AgentationDevtool.tsx
@@ -0,0 +1,12 @@
+"use client";
+
+import dynamic from "next/dynamic";
+
+const Agentation = dynamic(
+  () => import("agentation").then((mod) => mod.Agentation),
+  { ssr: false },
+);
+
+export default function AgentationDevtool() {
+  return <Agentation />;
+}
diff --git a/autogpt_platform/frontend/src/contexts/AutoPilotBridgeContext.tsx b/autogpt_platform/frontend/src/contexts/AutoPilotBridgeContext.tsx
new file mode 100644
index 0000000000..6fee4c1a1a
--- /dev/null
+++ b/autogpt_platform/frontend/src/contexts/AutoPilotBridgeContext.tsx
@@ -0,0 +1,64 @@
+"use client";
+
+import { createContext, useContext, useState } from "react";
+import { useRouter } from "next/navigation";
+
+const STORAGE_KEY = "autopilot_pending_prompt";
+
+interface AutoPilotBridgeState {
+  pendingPrompt: string | null;
+  sendPrompt: (prompt: string) => void;
+  consumePrompt: () => string | null;
+}
+
+const AutoPilotBridgeContext = createContext<AutoPilotBridgeState | null>(null);
+
+interface Props {
+  children: React.ReactNode;
+}
+
+export function AutoPilotBridgeProvider({ children }: Props) {
+  const router = useRouter();
+
+  const [pendingPrompt, setPendingPrompt] = useState<string | null>(() => {
+    if (typeof window === "undefined") return null;
+    return sessionStorage.getItem(STORAGE_KEY);
+  });
+
+  function sendPrompt(prompt: string) {
+    sessionStorage.setItem(STORAGE_KEY, prompt);
+    setPendingPrompt(prompt);
+    router.push("/");
+  }
+
+  function consumePrompt(): string | null {
+    const prompt = pendingPrompt ?? sessionStorage.getItem(STORAGE_KEY);
+    if (prompt !== null) {
+      sessionStorage.removeItem(STORAGE_KEY);
+      setPendingPrompt(null);
+    }
+    return prompt;
+  }
+
+  return (
+    <AutoPilotBridgeContext.Provider
+      value={{ pendingPrompt, sendPrompt, consumePrompt }}
+    >
+      {children}
+    </AutoPilotBridgeContext.Provider>
+  );
+}
+
+export function useAutoPilotBridge(): AutoPilotBridgeState {
+  const context = useContext(AutoPilotBridgeContext);
+  if (!context) {
+    // Return a no-op implementation when used outside the provider
+    // (e.g. in tests or isolated component renders).
+    return {
+      pendingPrompt: null,
+      sendPrompt: () => {},
+      consumePrompt: () => null,
+    };
+  }
+  return context;
+}
diff --git a/autogpt_platform/frontend/src/lib/autogpt-server-api/client.ts b/autogpt_platform/frontend/src/lib/autogpt-server-api/client.ts
index 9b51f2156f..961776e79e 100644
--- a/autogpt_platform/frontend/src/lib/autogpt-server-api/client.ts
+++ b/autogpt_platform/frontend/src/lib/autogpt-server-api/client.ts
@@ -194,26 +194,6 @@ export default class BackendAPI {
     return this._request("PATCH", "/credits");
   }

-  getSubscription(): Promise<{
-    tier: string;
-    monthly_cost: number;
-    tier_costs: Record<string, number>;
-  }> {
-    return this._get("/credits/subscription");
-  }
-
-  setSubscriptionTier(
-    tier: string,
-    successUrl?: string,
-    cancelUrl?: string,
-  ): Promise<{ url: string }> {
-    return this._request("POST", "/credits/subscription", {
-      tier,
-      success_url: successUrl ?? "",
-      cancel_url: cancelUrl ?? "",
-    });
-  }
-
   ////////////////////////////////////////
   //////////////// GRAPHS ////////////////
   ////////////////////////////////////////
diff --git a/autogpt_platform/frontend/src/playwright/library-happy-path.spec.ts b/autogpt_platform/frontend/src/playwright/library-happy-path.spec.ts
index f7ed0e796c..17c02397a6 100644
--- a/autogpt_platform/frontend/src/playwright/library-happy-path.spec.ts
+++ b/autogpt_platform/frontend/src/playwright/library-happy-path.spec.ts
@@ -385,7 +385,9 @@ test("library happy path: user can edit a saved agent from Library and keep chan
     .context()
     .waitForEvent("page", { timeout: 10000 })
     .catch(() => null);
-  await agentCard
+  // "Edit agent" link is inside the three-dot dropdown menu
+  await agentCard.getByRole("button", { name: "More actions" }).first().click();
+  await page
     .getByTestId("library-agent-card-open-in-builder-link")
     .first()
     .click();
diff --git a/autogpt_platform/frontend/src/playwright/pages/library.page.ts b/autogpt_platform/frontend/src/playwright/pages/library.page.ts
index 85c3f3978a..f2e648b341 100644
--- a/autogpt_platform/frontend/src/playwright/pages/library.page.ts
+++ b/autogpt_platform/frontend/src/playwright/pages/library.page.ts
@@ -262,13 +262,19 @@ export class LibraryPage extends BasePage {

   async clickOpenInBuilder(agent: Agent): Promise<void> {
     console.log(`clicking open in builder for agent: ${agent.name}`);
-    const { getId } = getSelectors(this.page);
-    const agentCard = getId("library-agent-card").filter({
-      hasText: agent.name,
+    const agentCard = this.page
+      .getByTestId("library-agent-card")
+      .filter({ hasText: agent.name });
+
+    // The "Edit agent" link is inside the three-dot dropdown menu.
+    // Open the menu first, then click the builder link.
+    const menuTrigger = agentCard.getByRole("button", {
+      name: "More actions",
     });
-    const builderLink = getId(
+    await menuTrigger.first().click();
+
+    const builderLink = this.page.getByTestId(
       "library-agent-card-open-in-builder-link",
-      agentCard,
     );
     await builderLink.first().click();
   }
diff --git a/autogpt_platform/frontend/src/services/feature-flags/use-get-flag.ts b/autogpt_platform/frontend/src/services/feature-flags/use-get-flag.ts
index e16f5b765a..78c82acc5c 100644
--- a/autogpt_platform/frontend/src/services/feature-flags/use-get-flag.ts
+++ b/autogpt_platform/frontend/src/services/feature-flags/use-get-flag.ts
@@ -11,6 +11,7 @@ export enum Flag {
   ARTIFACTS = "artifacts",
   CHAT_MODE_OPTION = "chat-mode-option",
   BUILDER_CHAT_PANEL = "builder-chat-panel",
+  AGENT_BRIEFING = "agent-briefing",
 }

 const isPwMockEnabled = process.env.NEXT_PUBLIC_PW_TEST === "true";
@@ -22,6 +23,7 @@ const defaultFlags = {
   [Flag.ARTIFACTS]: false,
   [Flag.CHAT_MODE_OPTION]: false,
   [Flag.BUILDER_CHAT_PANEL]: false,
+  [Flag.AGENT_BRIEFING]: false,
 };

 type FlagValues = typeof defaultFlags;
diff --git a/docs/integrations/block-integrations/llm.md b/docs/integrations/block-integrations/llm.md
index 77da6fd5d0..e0d39ed302 100644
--- a/docs/integrations/block-integrations/llm.md
+++ b/docs/integrations/block-integrations/llm.md
@@ -65,7 +65,7 @@ The result routes data to yes_output or no_output, enabling intelligent branchin
 | condition | A plaintext English description of the condition to evaluate | str | Yes |
 | yes_value | (Optional) Value to output if the condition is true. If not provided, input_value will be used. | Yes Value | No |
 | no_value | (Optional) Value to output if the condition is false. If not provided, input_value will be used. 
| No Value | No | -| model | The language model to use for evaluating the condition. | "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-opus-4-6" \| "claude-sonnet-4-6" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-2.5-pro" \| "google/gemini-3.1-pro-preview" \| "google/gemini-3-flash-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-3.1-flash-lite-preview" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "mistralai/mistral-large-2512" \| "mistralai/mistral-medium-3.1" \| "mistralai/mistral-small-3.2-24b-instruct" \| "mistralai/codestral-2508" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "cohere/command-a-03-2025" \| "cohere/command-a-translate-08-2025" \| "cohere/command-a-reasoning-08-2025" \| "cohere/command-a-vision-07-2025" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-reasoning-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "microsoft/phi-4" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-3" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "z-ai/glm-4-32b" \| "z-ai/glm-4.5" \| "z-ai/glm-4.5-air" \| "z-ai/glm-4.5-air:free" \| "z-ai/glm-4.5v" \| "z-ai/glm-4.6" \| "z-ai/glm-4.6v" \| "z-ai/glm-4.7" \| "z-ai/glm-4.7-flash" \| "z-ai/glm-5" \| "z-ai/glm-5-turbo" \| "z-ai/glm-5v-turbo" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No | +| model | The language model to use for evaluating the condition. 
| "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-opus-4-6" \| "claude-sonnet-4-6" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-2.5-pro" \| "google/gemini-3.1-pro-preview" \| "google/gemini-3-flash-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-3.1-flash-lite-preview" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "mistralai/mistral-large-2512" \| "mistralai/mistral-medium-3.1" \| "mistralai/mistral-small-3.2-24b-instruct" \| "mistralai/codestral-2508" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "cohere/command-a-03-2025" \| "cohere/command-a-translate-08-2025" \| "cohere/command-a-reasoning-08-2025" \| "cohere/command-a-vision-07-2025" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-reasoning-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "microsoft/phi-4" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-3" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-4.20" \| "x-ai/grok-4.20-multi-agent" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "z-ai/glm-4-32b" \| "z-ai/glm-4.5" \| "z-ai/glm-4.5-air" \| "z-ai/glm-4.5-air:free" \| "z-ai/glm-4.5v" \| "z-ai/glm-4.6" \| "z-ai/glm-4.6v" \| "z-ai/glm-4.7" \| "z-ai/glm-4.7-flash" \| "z-ai/glm-5" \| "z-ai/glm-5-turbo" \| "z-ai/glm-5v-turbo" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No | ### Outputs @@ -103,7 +103,7 @@ The block sends the entire conversation history to the chosen LLM, including sys |-------|-------------|------|----------| | prompt | The prompt to send to the language model. | str | No | | messages | List of messages in the conversation. | List[Any] | Yes | -| model | The language model to use for the conversation. 
| "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-opus-4-6" \| "claude-sonnet-4-6" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-2.5-pro" \| "google/gemini-3.1-pro-preview" \| "google/gemini-3-flash-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-3.1-flash-lite-preview" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "mistralai/mistral-large-2512" \| "mistralai/mistral-medium-3.1" \| "mistralai/mistral-small-3.2-24b-instruct" \| "mistralai/codestral-2508" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "cohere/command-a-03-2025" \| "cohere/command-a-translate-08-2025" \| "cohere/command-a-reasoning-08-2025" \| "cohere/command-a-vision-07-2025" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-reasoning-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "microsoft/phi-4" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-3" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "z-ai/glm-4-32b" \| "z-ai/glm-4.5" \| "z-ai/glm-4.5-air" \| "z-ai/glm-4.5-air:free" \| "z-ai/glm-4.5v" \| "z-ai/glm-4.6" \| "z-ai/glm-4.6v" \| "z-ai/glm-4.7" \| "z-ai/glm-4.7-flash" \| "z-ai/glm-5" \| "z-ai/glm-5-turbo" \| "z-ai/glm-5v-turbo" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No | +| model | The language model to use for the conversation. 
| "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-opus-4-6" \| "claude-sonnet-4-6" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-2.5-pro" \| "google/gemini-3.1-pro-preview" \| "google/gemini-3-flash-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-3.1-flash-lite-preview" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "mistralai/mistral-large-2512" \| "mistralai/mistral-medium-3.1" \| "mistralai/mistral-small-3.2-24b-instruct" \| "mistralai/codestral-2508" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "cohere/command-a-03-2025" \| "cohere/command-a-translate-08-2025" \| "cohere/command-a-reasoning-08-2025" \| "cohere/command-a-vision-07-2025" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-reasoning-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "microsoft/phi-4" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-3" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-4.20" \| "x-ai/grok-4.20-multi-agent" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "z-ai/glm-4-32b" \| "z-ai/glm-4.5" \| "z-ai/glm-4.5-air" \| "z-ai/glm-4.5-air:free" \| "z-ai/glm-4.5v" \| "z-ai/glm-4.6" \| "z-ai/glm-4.6v" \| "z-ai/glm-4.7" \| "z-ai/glm-4.7-flash" \| "z-ai/glm-5" \| "z-ai/glm-5-turbo" \| "z-ai/glm-5v-turbo" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No | | max_tokens | The maximum number of tokens to generate in the chat completion. | int | No | | ollama_host | Ollama host for local models | str | No | @@ -257,7 +257,7 @@ The block formulates a prompt based on the given focus or source data, sends it |-------|-------------|------|----------| | focus | The focus of the list to generate. | str | No | | source_data | The data to generate the list from. | str | No | -| model | The language model to use for generating the list. 
| "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-opus-4-6" \| "claude-sonnet-4-6" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-2.5-pro" \| "google/gemini-3.1-pro-preview" \| "google/gemini-3-flash-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-3.1-flash-lite-preview" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "mistralai/mistral-large-2512" \| "mistralai/mistral-medium-3.1" \| "mistralai/mistral-small-3.2-24b-instruct" \| "mistralai/codestral-2508" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "cohere/command-a-03-2025" \| "cohere/command-a-translate-08-2025" \| "cohere/command-a-reasoning-08-2025" \| "cohere/command-a-vision-07-2025" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-reasoning-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "microsoft/phi-4" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-3" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "z-ai/glm-4-32b" \| "z-ai/glm-4.5" \| "z-ai/glm-4.5-air" \| "z-ai/glm-4.5-air:free" \| "z-ai/glm-4.5v" \| "z-ai/glm-4.6" \| "z-ai/glm-4.6v" \| "z-ai/glm-4.7" \| "z-ai/glm-4.7-flash" \| "z-ai/glm-5" \| "z-ai/glm-5-turbo" \| "z-ai/glm-5v-turbo" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No | +| model | The language model to use for generating the list. 
| "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-opus-4-6" \| "claude-sonnet-4-6" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-2.5-pro" \| "google/gemini-3.1-pro-preview" \| "google/gemini-3-flash-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-3.1-flash-lite-preview" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "mistralai/mistral-large-2512" \| "mistralai/mistral-medium-3.1" \| "mistralai/mistral-small-3.2-24b-instruct" \| "mistralai/codestral-2508" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "cohere/command-a-03-2025" \| "cohere/command-a-translate-08-2025" \| "cohere/command-a-reasoning-08-2025" \| "cohere/command-a-vision-07-2025" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-reasoning-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "microsoft/phi-4" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-3" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-4.20" \| "x-ai/grok-4.20-multi-agent" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "z-ai/glm-4-32b" \| "z-ai/glm-4.5" \| "z-ai/glm-4.5-air" \| "z-ai/glm-4.5-air:free" \| "z-ai/glm-4.5v" \| "z-ai/glm-4.6" \| "z-ai/glm-4.6v" \| "z-ai/glm-4.7" \| "z-ai/glm-4.7-flash" \| "z-ai/glm-5" \| "z-ai/glm-5-turbo" \| "z-ai/glm-5v-turbo" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No | | max_retries | Maximum number of retries for generating a valid list. | int | No | | force_json_output | Whether to force the LLM to produce a JSON-only response. This can increase the block's reliability, but may also reduce the quality of the response because it prohibits the LLM from reasoning before providing its JSON response. | bool | No | | max_tokens | The maximum number of tokens to generate in the chat completion. | int | No | @@ -424,7 +424,7 @@ The block sends the input prompt to a chosen LLM, along with any system prompts | prompt | The prompt to send to the language model. | str | Yes | | expected_format | Expected format of the response. 
If provided, the response will be validated against this format. The keys should be the expected fields in the response, and the values should be the description of the field. | Dict[str, str] | Yes | | list_result | Whether the response should be a list of objects in the expected format. | bool | No | -| model | The language model to use for answering the prompt. | "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-opus-4-6" \| "claude-sonnet-4-6" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-2.5-pro" \| "google/gemini-3.1-pro-preview" \| "google/gemini-3-flash-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-3.1-flash-lite-preview" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "mistralai/mistral-large-2512" \| "mistralai/mistral-medium-3.1" \| "mistralai/mistral-small-3.2-24b-instruct" \| "mistralai/codestral-2508" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "cohere/command-a-03-2025" \| "cohere/command-a-translate-08-2025" \| "cohere/command-a-reasoning-08-2025" \| "cohere/command-a-vision-07-2025" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-reasoning-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "microsoft/phi-4" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-3" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "z-ai/glm-4-32b" \| "z-ai/glm-4.5" \| "z-ai/glm-4.5-air" \| "z-ai/glm-4.5-air:free" \| "z-ai/glm-4.5v" \| "z-ai/glm-4.6" \| "z-ai/glm-4.6v" \| "z-ai/glm-4.7" \| "z-ai/glm-4.7-flash" \| "z-ai/glm-5" \| "z-ai/glm-5-turbo" \| "z-ai/glm-5v-turbo" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No | +| model | The language model to use for answering the prompt. 
| "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-opus-4-6" \| "claude-sonnet-4-6" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-2.5-pro" \| "google/gemini-3.1-pro-preview" \| "google/gemini-3-flash-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-3.1-flash-lite-preview" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "mistralai/mistral-large-2512" \| "mistralai/mistral-medium-3.1" \| "mistralai/mistral-small-3.2-24b-instruct" \| "mistralai/codestral-2508" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "cohere/command-a-03-2025" \| "cohere/command-a-translate-08-2025" \| "cohere/command-a-reasoning-08-2025" \| "cohere/command-a-vision-07-2025" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-reasoning-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "microsoft/phi-4" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-3" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-4.20" \| "x-ai/grok-4.20-multi-agent" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "z-ai/glm-4-32b" \| "z-ai/glm-4.5" \| "z-ai/glm-4.5-air" \| "z-ai/glm-4.5-air:free" \| "z-ai/glm-4.5v" \| "z-ai/glm-4.6" \| "z-ai/glm-4.6v" \| "z-ai/glm-4.7" \| "z-ai/glm-4.7-flash" \| "z-ai/glm-5" \| "z-ai/glm-5-turbo" \| "z-ai/glm-5v-turbo" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No | | force_json_output | Whether to force the LLM to produce a JSON-only response. This can increase the block's reliability, but may also reduce the quality of the response because it prohibits the LLM from reasoning before providing its JSON response. | bool | No | | sys_prompt | The system prompt to provide additional context to the model. | str | No | | conversation_history | The conversation history to provide context for the prompt. | List[Dict[str, Any]] | No | @@ -464,7 +464,7 @@ The block sends the input prompt to a chosen LLM, processes the response, and re | Input | Description | Type | Required | |-------|-------------|------|----------| | prompt | The prompt to send to the language model. 
You can use any of the {keys} from Prompt Values to fill in the prompt with values from the prompt values dictionary by putting them in curly braces. | str | Yes | -| model | The language model to use for answering the prompt. | "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-opus-4-6" \| "claude-sonnet-4-6" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-2.5-pro" \| "google/gemini-3.1-pro-preview" \| "google/gemini-3-flash-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-3.1-flash-lite-preview" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "mistralai/mistral-large-2512" \| "mistralai/mistral-medium-3.1" \| "mistralai/mistral-small-3.2-24b-instruct" \| "mistralai/codestral-2508" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "cohere/command-a-03-2025" \| "cohere/command-a-translate-08-2025" \| "cohere/command-a-reasoning-08-2025" \| "cohere/command-a-vision-07-2025" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-reasoning-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "microsoft/phi-4" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-3" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "z-ai/glm-4-32b" \| "z-ai/glm-4.5" \| "z-ai/glm-4.5-air" \| "z-ai/glm-4.5-air:free" \| "z-ai/glm-4.5v" \| "z-ai/glm-4.6" \| "z-ai/glm-4.6v" \| "z-ai/glm-4.7" \| "z-ai/glm-4.7-flash" \| "z-ai/glm-5" \| "z-ai/glm-5-turbo" \| "z-ai/glm-5v-turbo" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No | +| model | The language model to use for answering the prompt. 
| "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-opus-4-6" \| "claude-sonnet-4-6" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-2.5-pro" \| "google/gemini-3.1-pro-preview" \| "google/gemini-3-flash-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-3.1-flash-lite-preview" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "mistralai/mistral-large-2512" \| "mistralai/mistral-medium-3.1" \| "mistralai/mistral-small-3.2-24b-instruct" \| "mistralai/codestral-2508" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "cohere/command-a-03-2025" \| "cohere/command-a-translate-08-2025" \| "cohere/command-a-reasoning-08-2025" \| "cohere/command-a-vision-07-2025" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-reasoning-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "microsoft/phi-4" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-3" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-4.20" \| "x-ai/grok-4.20-multi-agent" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "z-ai/glm-4-32b" \| "z-ai/glm-4.5" \| "z-ai/glm-4.5-air" \| "z-ai/glm-4.5-air:free" \| "z-ai/glm-4.5v" \| "z-ai/glm-4.6" \| "z-ai/glm-4.6v" \| "z-ai/glm-4.7" \| "z-ai/glm-4.7-flash" \| "z-ai/glm-5" \| "z-ai/glm-5-turbo" \| "z-ai/glm-5v-turbo" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No | | sys_prompt | The system prompt to provide additional context to the model. | str | No | | retry | Number of times to retry the LLM call if the response does not match the expected format. | int | No | | prompt_values | Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}. | Dict[str, str] | No | @@ -501,7 +501,7 @@ The block splits the input text into smaller chunks, sends each chunk to an LLM | Input | Description | Type | Required | |-------|-------------|------|----------| | text | The text to summarize. | str | Yes | -| model | The language model to use for summarizing the text. 
| "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-opus-4-6" \| "claude-sonnet-4-6" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-2.5-pro" \| "google/gemini-3.1-pro-preview" \| "google/gemini-3-flash-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-3.1-flash-lite-preview" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "mistralai/mistral-large-2512" \| "mistralai/mistral-medium-3.1" \| "mistralai/mistral-small-3.2-24b-instruct" \| "mistralai/codestral-2508" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "cohere/command-a-03-2025" \| "cohere/command-a-translate-08-2025" \| "cohere/command-a-reasoning-08-2025" \| "cohere/command-a-vision-07-2025" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-reasoning-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "microsoft/phi-4" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-3" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "z-ai/glm-4-32b" \| "z-ai/glm-4.5" \| "z-ai/glm-4.5-air" \| "z-ai/glm-4.5-air:free" \| "z-ai/glm-4.5v" \| "z-ai/glm-4.6" \| "z-ai/glm-4.6v" \| "z-ai/glm-4.7" \| "z-ai/glm-4.7-flash" \| "z-ai/glm-5" \| "z-ai/glm-5-turbo" \| "z-ai/glm-5v-turbo" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No | +| model | The language model to use for summarizing the text. 
| "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-opus-4-6" \| "claude-sonnet-4-6" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-2.5-pro" \| "google/gemini-3.1-pro-preview" \| "google/gemini-3-flash-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-3.1-flash-lite-preview" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "mistralai/mistral-large-2512" \| "mistralai/mistral-medium-3.1" \| "mistralai/mistral-small-3.2-24b-instruct" \| "mistralai/codestral-2508" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "cohere/command-a-03-2025" \| "cohere/command-a-translate-08-2025" \| "cohere/command-a-reasoning-08-2025" \| "cohere/command-a-vision-07-2025" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-reasoning-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "microsoft/phi-4" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-3" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-4.20" \| "x-ai/grok-4.20-multi-agent" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "z-ai/glm-4-32b" \| "z-ai/glm-4.5" \| "z-ai/glm-4.5-air" \| "z-ai/glm-4.5-air:free" \| "z-ai/glm-4.5v" \| "z-ai/glm-4.6" \| "z-ai/glm-4.6v" \| "z-ai/glm-4.7" \| "z-ai/glm-4.7-flash" \| "z-ai/glm-5" \| "z-ai/glm-5-turbo" \| "z-ai/glm-5v-turbo" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No | | focus | The topic to focus on in the summary | str | No | | style | The style of the summary to generate. | "concise" \| "detailed" \| "bullet points" \| "numbered list" | No | | max_tokens | The maximum number of tokens to generate in the chat completion. | int | No | @@ -721,7 +721,7 @@ _Add technical explanation here._ | Input | Description | Type | Required | |-------|-------------|------|----------| | prompt | The prompt to send to the language model. | str | Yes | -| model | The language model to use for answering the prompt. 
| "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-opus-4-6" \| "claude-sonnet-4-6" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-2.5-pro" \| "google/gemini-3.1-pro-preview" \| "google/gemini-3-flash-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-3.1-flash-lite-preview" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "mistralai/mistral-large-2512" \| "mistralai/mistral-medium-3.1" \| "mistralai/mistral-small-3.2-24b-instruct" \| "mistralai/codestral-2508" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "cohere/command-a-03-2025" \| "cohere/command-a-translate-08-2025" \| "cohere/command-a-reasoning-08-2025" \| "cohere/command-a-vision-07-2025" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-reasoning-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "microsoft/phi-4" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-3" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "z-ai/glm-4-32b" \| "z-ai/glm-4.5" \| "z-ai/glm-4.5-air" \| "z-ai/glm-4.5-air:free" \| "z-ai/glm-4.5v" \| "z-ai/glm-4.6" \| "z-ai/glm-4.6v" \| "z-ai/glm-4.7" \| "z-ai/glm-4.7-flash" \| "z-ai/glm-5" \| "z-ai/glm-5-turbo" \| "z-ai/glm-5v-turbo" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No | +| model | The language model to use for answering the prompt. 
| "o3-mini" \| "o3-2025-04-16" \| "o1" \| "o1-mini" \| "gpt-5.2-2025-12-11" \| "gpt-5.1-2025-11-13" \| "gpt-5-2025-08-07" \| "gpt-5-mini-2025-08-07" \| "gpt-5-nano-2025-08-07" \| "gpt-5-chat-latest" \| "gpt-4.1-2025-04-14" \| "gpt-4.1-mini-2025-04-14" \| "gpt-4o-mini" \| "gpt-4o" \| "gpt-4-turbo" \| "claude-opus-4-1-20250805" \| "claude-opus-4-20250514" \| "claude-sonnet-4-20250514" \| "claude-opus-4-5-20251101" \| "claude-sonnet-4-5-20250929" \| "claude-haiku-4-5-20251001" \| "claude-opus-4-6" \| "claude-sonnet-4-6" \| "claude-3-haiku-20240307" \| "Qwen/Qwen2.5-72B-Instruct-Turbo" \| "nvidia/llama-3.1-nemotron-70b-instruct" \| "meta-llama/Llama-3.3-70B-Instruct-Turbo" \| "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo" \| "meta-llama/Llama-3.2-3B-Instruct-Turbo" \| "llama-3.3-70b-versatile" \| "llama-3.1-8b-instant" \| "llama3.3" \| "llama3.2" \| "llama3" \| "llama3.1:405b" \| "dolphin-mistral:latest" \| "openai/gpt-oss-120b" \| "openai/gpt-oss-20b" \| "google/gemini-2.5-pro-preview-03-25" \| "google/gemini-2.5-pro" \| "google/gemini-3.1-pro-preview" \| "google/gemini-3-flash-preview" \| "google/gemini-2.5-flash" \| "google/gemini-2.0-flash-001" \| "google/gemini-3.1-flash-lite-preview" \| "google/gemini-2.5-flash-lite-preview-06-17" \| "google/gemini-2.0-flash-lite-001" \| "mistralai/mistral-nemo" \| "mistralai/mistral-large-2512" \| "mistralai/mistral-medium-3.1" \| "mistralai/mistral-small-3.2-24b-instruct" \| "mistralai/codestral-2508" \| "cohere/command-r-08-2024" \| "cohere/command-r-plus-08-2024" \| "cohere/command-a-03-2025" \| "cohere/command-a-translate-08-2025" \| "cohere/command-a-reasoning-08-2025" \| "cohere/command-a-vision-07-2025" \| "deepseek/deepseek-chat" \| "deepseek/deepseek-r1-0528" \| "perplexity/sonar" \| "perplexity/sonar-pro" \| "perplexity/sonar-reasoning-pro" \| "perplexity/sonar-deep-research" \| "nousresearch/hermes-3-llama-3.1-405b" \| "nousresearch/hermes-3-llama-3.1-70b" \| "amazon/nova-lite-v1" \| "amazon/nova-micro-v1" \| "amazon/nova-pro-v1" \| "microsoft/wizardlm-2-8x22b" \| "microsoft/phi-4" \| "gryphe/mythomax-l2-13b" \| "meta-llama/llama-4-scout" \| "meta-llama/llama-4-maverick" \| "x-ai/grok-3" \| "x-ai/grok-4" \| "x-ai/grok-4-fast" \| "x-ai/grok-4.1-fast" \| "x-ai/grok-4.20" \| "x-ai/grok-4.20-multi-agent" \| "x-ai/grok-code-fast-1" \| "moonshotai/kimi-k2" \| "qwen/qwen3-235b-a22b-thinking-2507" \| "qwen/qwen3-coder" \| "z-ai/glm-4-32b" \| "z-ai/glm-4.5" \| "z-ai/glm-4.5-air" \| "z-ai/glm-4.5-air:free" \| "z-ai/glm-4.5v" \| "z-ai/glm-4.6" \| "z-ai/glm-4.6v" \| "z-ai/glm-4.7" \| "z-ai/glm-4.7-flash" \| "z-ai/glm-5" \| "z-ai/glm-5-turbo" \| "z-ai/glm-5v-turbo" \| "Llama-4-Scout-17B-16E-Instruct-FP8" \| "Llama-4-Maverick-17B-128E-Instruct-FP8" \| "Llama-3.3-8B-Instruct" \| "Llama-3.3-70B-Instruct" \| "v0-1.5-md" \| "v0-1.5-lg" \| "v0-1.0-md" | No | | multiple_tool_calls | Whether to allow multiple tool calls in a single response. | bool | No | | sys_prompt | The system prompt to provide additional context to the model. | str | No | | conversation_history | The conversation history to provide context for the prompt. | List[Dict[str, Any]] | No | diff --git a/docs/integrations/block-integrations/misc.md b/docs/integrations/block-integrations/misc.md index ef7fd938db..c494903c38 100644 --- a/docs/integrations/block-integrations/misc.md +++ b/docs/integrations/block-integrations/misc.md @@ -58,7 +58,7 @@ Tool and block identifiers provided in `tools` and `blocks` are validated at run | system_context | Optional additional context prepended to the prompt. 
Use this to constrain autopilot behavior, provide domain context, or set output format requirements. | str | No | | session_id | Session ID to continue an existing autopilot conversation. Leave empty to start a new session. Use the session_id output from a previous run to continue. | str | No | | max_recursion_depth | Maximum nesting depth when the autopilot calls this block recursively (sub-agent pattern). Prevents infinite loops. | int | No | -| tools | Tool names to filter. Works with tools_exclude to form an allow-list or deny-list. Leave empty to apply no tool filter. | List["add_understanding" \| "ask_question" \| "bash_exec" \| "browser_act" \| "browser_navigate" \| "browser_screenshot" \| "connect_integration" \| "continue_run_block" \| "create_agent" \| "create_feature_request" \| "create_folder" \| "customize_agent" \| "delete_folder" \| "delete_workspace_file" \| "edit_agent" \| "find_agent" \| "find_block" \| "find_library_agent" \| "fix_agent_graph" \| "get_agent_building_guide" \| "get_doc_page" \| "get_mcp_guide" \| "list_folders" \| "list_workspace_files" \| "memory_search" \| "memory_store" \| "move_agents_to_folder" \| "move_folder" \| "read_workspace_file" \| "run_agent" \| "run_block" \| "run_mcp_tool" \| "search_docs" \| "search_feature_requests" \| "update_folder" \| "validate_agent_graph" \| "view_agent_output" \| "web_fetch" \| "write_workspace_file" \| "Agent" \| "Edit" \| "Glob" \| "Grep" \| "Read" \| "Task" \| "TodoWrite" \| "WebSearch" \| "Write"] | No | +| tools | Tool names to filter. Works with tools_exclude to form an allow-list or deny-list. Leave empty to apply no tool filter. | List["add_understanding" \| "ask_question" \| "bash_exec" \| "browser_act" \| "browser_navigate" \| "browser_screenshot" \| "connect_integration" \| "continue_run_block" \| "create_agent" \| "create_feature_request" \| "create_folder" \| "customize_agent" \| "delete_folder" \| "delete_workspace_file" \| "edit_agent" \| "find_agent" \| "find_block" \| "find_library_agent" \| "fix_agent_graph" \| "get_agent_building_guide" \| "get_doc_page" \| "get_mcp_guide" \| "list_folders" \| "list_workspace_files" \| "memory_forget_confirm" \| "memory_forget_search" \| "memory_search" \| "memory_store" \| "move_agents_to_folder" \| "move_folder" \| "read_workspace_file" \| "run_agent" \| "run_block" \| "run_mcp_tool" \| "search_docs" \| "search_feature_requests" \| "update_folder" \| "validate_agent_graph" \| "view_agent_output" \| "web_fetch" \| "write_workspace_file" \| "Agent" \| "Edit" \| "Glob" \| "Grep" \| "Read" \| "Task" \| "TodoWrite" \| "WebSearch" \| "Write"] | No | | tools_exclude | Controls how the 'tools' list is interpreted. True (default): 'tools' is a deny-list — listed tools are blocked, all others are allowed. An empty 'tools' list means allow everything. False: 'tools' is an allow-list — only listed tools are permitted. | bool | No | | blocks | Block identifiers to filter when the copilot uses run_block. Each entry can be: a block name (e.g. 'HTTP Request'), a full block UUID, or the first 8 hex characters of the UUID (e.g. 'c069dc6b'). Works with blocks_exclude. Leave empty to apply no block filter. | List[str] | No | | blocks_exclude | Controls how the 'blocks' list is interpreted. True (default): 'blocks' is a deny-list — listed blocks are blocked, all others are allowed. An empty 'blocks' list means allow everything. False: 'blocks' is an allow-list — only listed blocks are permitted. | bool | No |
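
A closing note on the `tools` / `tools_exclude` semantics documented in the misc.md table above: the deny-list vs. allow-list behavior reduces to a small predicate. The sketch below only illustrates that documented contract; it is written in TypeScript for consistency with the frontend diffs in this PR, the actual filter lives in the Python backend, and `isToolAllowed` is a hypothetical name, not a function from this changeset.

```typescript
// Illustration of the documented allow/deny-list contract, not backend code.
// `tools` and `toolsExclude` mirror the block inputs of the same names;
// the same logic would apply to `blocks` / `blocks_exclude`.
function isToolAllowed(
  toolName: string,
  tools: string[],
  toolsExclude: boolean,
): boolean {
  if (toolsExclude) {
    // Deny-list mode (the default): listed tools are blocked, all others
    // are allowed, so an empty list allows everything.
    return !tools.includes(toolName);
  }
  // Allow-list mode: only listed tools are permitted (by implication,
  // an empty allow-list permits nothing).
  return tools.includes(toolName);
}

// With tools_exclude=true (deny-list), "bash_exec" is blocked while every
// other tool stays available; with tools_exclude=false (allow-list), only
// the listed tools pass.
console.assert(isToolAllowed("web_fetch", ["bash_exec"], true) === true);
console.assert(isToolAllowed("bash_exec", ["bash_exec"], true) === false);
console.assert(isToolAllowed("web_fetch", ["memory_search"], false) === false);
```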