From 1c0c7a6b44dc5831930091b95fbdf9bafef253b4 Mon Sep 17 00:00:00 2001 From: Toran Bruce Richards Date: Fri, 17 Apr 2026 16:22:10 +0100 Subject: [PATCH 01/41] fix(copilot): add gh auth status check to Tool Discovery Priority section (#12832) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Problem The CoPilot system prompt contains a `gh auth status` instruction in the E2B-specific `GitHub CLI` section, but models pattern-match to `connect_integration` from the **Tool Discovery Priority** section — which is where the actual decision to call an external service is made. Because the GitHub auth check lives in a separate, later section, it's not salient at the point of decision-making. This causes the model to call `connect_integration(provider='github')` even when `gh` is already authenticated via `GH_TOKEN`, unnecessarily prompting the user. ## Fix Add a 3-line callout directly inside the **Tool Discovery Priority** section: ``` > 🔑 **GitHub exception:** Before calling `connect_integration` for GitHub, > always run `gh auth status` first. If it shows `Logged in`, proceed > directly with `gh`/`git` — no integration connection needed. ``` This places the rule at the exact location where the model decides which tool path to take, preventing the miss. ## Why this works - **Placement over repetition**: The existing instruction isn't wrong — it's just in the wrong spot relative to where the decision is made - **Negative framing**: Explicitly says "before calling `connect_integration`" which directly intercepts the incorrect reflex - **Minimal change**: 4 lines added, zero removed Co-authored-by: Toran Bruce Richards <22963551+Torantulino@users.noreply.github.com> --- autogpt_platform/backend/backend/copilot/prompting.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/autogpt_platform/backend/backend/copilot/prompting.py b/autogpt_platform/backend/backend/copilot/prompting.py index ed436733dd..95339cc2ce 100644 --- a/autogpt_platform/backend/backend/copilot/prompting.py +++ b/autogpt_platform/backend/backend/copilot/prompting.py @@ -174,14 +174,18 @@ sandbox so `bash_exec` can access it for further processing. The exact sandbox path is shown in the `[Sandbox copy available at ...]` note. ### GitHub CLI (`gh`) and git -- To check if the user has their GitHub account already connected, run `gh auth status`. Always check this before asking them to connect it. +- To check if the user has their GitHub account already connected, run `gh auth status`. Always check this before running `connect_integration(provider="github")` which will ask the user to connect their GitHub regardless if it's already connected. - If the user has connected their GitHub account, both `gh` and `git` are pre-authenticated — use them directly without any manual login step. `git` HTTPS operations (clone, push, pull) work automatically. - If the token changes mid-session (e.g. user reconnects with a new token), run `gh auth setup-git` to re-register the credential helper. -- If `gh` or `git` fails with an authentication error (e.g. "authentication - required", "could not read Username", or exit code 128), call +- **MANDATORY:** You MUST run `gh auth status` before EVER calling + `connect_integration(provider="github")`. If it shows `Logged in`, + proceed directly — no integration connection needed. Never skip this check. +- If `gh auth status` shows NOT logged in, or `gh`/`git` fails with an + authentication error (e.g. "authentication required", "could not read + Username", or exit code 128), THEN call `connect_integration(provider="github")` to surface the GitHub credentials setup card so the user can connect their account. Once connected, retry the operation. From a8226af7259e2b7c43ad2185407f495c47480933 Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Tue, 21 Apr 2026 10:18:52 +0700 Subject: [PATCH 02/41] fix(copilot): dedupe tool row, lift bash_exec timeout, Stop+resend recovery (#12862) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #12861 · [OPEN-3096](https://linear.app/autogpt/issue/OPEN-3096) ## Why Four related copilot UX / stability issues surfaced on dev once action tools started rendering inline in the chat (see #12813): ### 1. Duplicate bash_exec row `GenericTool` rendered two rows saying the same thing for every completed tool call — a muted subtitle line ("Command exited with code 1" / "Ran: sleep 20") **and** a `ToolAccordion` with the command echoed in its description. Previously hidden inside the "Show reasoning" / "Show steps" collapse, now visibly duplicated. ### 2. `bash_exec` capped at 120s via advisory text The tool schema said `"Max seconds (default 30, max 120)"`; the model obeyed, so long-running scripts got clipped at 120s with a vague `Timed out after 120s` even though the E2B sandbox has no such limit. Confirmed via Langfuse traces — the model picks `120` for long scripts because that's what the schema told it the max was. E2B path never had a server-side clamp. Originally added in #12103 (default 30) and tightened to "max 120" advisory in #12398 (token-reduction pass). ### 3. 30s default was too aggressive `pip install`, small data-processing scripts, etc. routinely cross 30s and got killed before the model thought to retry with a bigger timeout. ### 4. Stop + edit + resend → "The assistant encountered an error" ([OPEN-3096](https://linear.app/autogpt/issue/OPEN-3096)) Two independent bugs both land on the same banner — fixing only one leaves the other visible on the next action. **4a. Stream lock never released on Stop** *(the error in the ticket screenshot)*. The executor's `async for chunk in stream_and_publish(...)` broke out on `cancel.is_set()` without calling `aclose()` on the wrapper. `async for` does NOT auto-close iterators on `break`, so `stream_chat_completion_sdk` stayed suspended at its current `await` — still holding the per-session Redis lock (TTL 120s) until GC eventually closed it. The next `POST /stream` hit `lock.try_acquire()` at [sdk/service.py](autogpt_platform/backend/backend/copilot/sdk/service.py) and yielded `StreamError("Another stream is already active for this session. Please wait or stop it.")`. The `except GeneratorExit → lock.release()` handler written exactly for this case never fired because nothing sent GeneratorExit. **4b. Orphan `tool_use` after stop-mid-tool.** Even with the lock released, the stop path persists the session ending on an assistant row whose `tool_calls` have no matching `role="tool"` row. On the next turn, `_session_messages_to_transcript` hands Claude CLI `--resume` a JSONL with a `tool_use` and no paired `tool_result`, and the SDK raises a vague error — same banner. The ticket's "Open questions" explicitly flags this. ## What **Frontend — `GenericTool.tsx`** split responsibilities between the two rows so they don't duplicate: - **Subtitle row** (always visible, muted): *what ran* — `Ran: sleep 120`. Never the exit code. - **Accordion description**: *how it ended* — `completed` / `status code 127 · bash: missing-bin: command not found` / `Timed out after 120s` / (fallback to command preview for legacy rows missing `exit_code` / `timed_out`). Pulled from the first non-empty line of `stdout` / `stderr` when available. - **Expanded accordion**: full command + stdout + stderr code blocks (unchanged). **Backend — `bash_exec.py`**: - Drop the "max 120" advisory from the schema description. - Bump default `timeout: 30 → 120`. - Clean up the result message — `"Command executed with status code 0"` (no "on E2B", no parens). **Backend — `executor/processor.py` + `stream_registry.py` (OPEN-3096 #4a)**: wrap the consumer `async for` in `try/finally: await stream.aclose()`. Close now propagates through `stream_and_publish` into `stream_chat_completion_sdk`, whose existing `except GeneratorExit → lock.release()` releases the Redis lock immediately on cancel. Stream types tightened to `AsyncGenerator[StreamBaseResponse, None]` so the defensive `getattr(stream, "aclose", None)` goes away. **Backend — `session_cleanup.py` (OPEN-3096 #4b)**: new `prune_orphan_tool_calls()` helper walks the trailing session tail and drops any trailing assistant row whose `tool_calls` have unresolved ids (plus everything after it) and any trailing `STOPPED_BY_USER_MARKER` system-stop row. Single backward pass — tolerates the marker being present or absent. Called from the existing turn-start cleanup in both `sdk/service.py` and `baseline/service.py`; takes an optional `log_prefix` so both paths emit the same INFO log when something was popped. In-memory only — the DB save path is append-only via `start_sequence`. ## Test plan - [x] `pnpm exec vitest run src/app/(platform)/copilot/tools/GenericTool src/app/(platform)/copilot/components/ChatMessagesContainer` — 105 pass (6 new for GenericTool subtitle/description variants + legacy-fallback case). - [x] `pnpm format` / `pnpm lint` / `pnpm types` — clean. - [x] `poetry run pytest backend/copilot/sdk/session_persistence_test.py` — 17 pass (6 + 3 new covering the orphan-tool-call prune and its optional-log-prefix branch). - [x] `poetry run pytest backend/copilot/stream_registry_test.py backend/copilot/executor/processor_test.py` — 19 pass (2 for aclose propagation on the `stream_and_publish` wrapper, 2 for `_execute_async` aclose propagation on both exit paths, 1 for publish_chunk RedisError warning ladder). - [x] `poetry run ruff check` / `poetry run pyright` on touched files — clean. - [x] Manual: fire a `bash_exec` — one labelled row, accordion description reads sensibly (`completed` / `status code 1 · …` / `Timed out after 120s`). - [x] Manual: script that needs >120s — no longer clipped. - [x] Manual: Stop mid-tool + edit + resend — Autopilot resumes without "Another stream is already active" and without the vague SDK error. ## Scope note Does not touch `splitReasoningAndResponse` — re-collapsing action tools back into "Show steps" is #12813's responsibility. --- .../backend/copilot/baseline/service.py | 9 +- .../backend/backend/copilot/constants.py | 5 + .../backend/copilot/executor/processor.py | 38 ++-- .../copilot/executor/processor_test.py | 104 +++++++++- .../copilot/sdk/response_adapter_test.py | 2 - .../backend/backend/copilot/sdk/service.py | 20 +- .../copilot/sdk/session_persistence_test.py | 182 ++++++++++++++++++ .../backend/copilot/session_cleanup.py | 77 ++++++++ .../backend/copilot/stream_registry.py | 56 +++--- .../backend/copilot/stream_registry_test.py | 113 +++++++++++ .../backend/copilot/tools/bash_exec.py | 12 +- .../copilot/tools/GenericTool/GenericTool.tsx | 33 +++- .../__tests__/GenericTool.test.tsx | 139 +++++++++++++ .../GenericTool/__tests__/helpers.test.ts | 4 +- .../copilot/tools/GenericTool/helpers.ts | 21 +- 15 files changed, 733 insertions(+), 82 deletions(-) create mode 100644 autogpt_platform/backend/backend/copilot/session_cleanup.py create mode 100644 autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/__tests__/GenericTool.test.tsx diff --git a/autogpt_platform/backend/backend/copilot/baseline/service.py b/autogpt_platform/backend/backend/copilot/baseline/service.py index 4c6ad04d60..7d27beac8b 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service.py @@ -38,7 +38,6 @@ from backend.copilot.model import ( from backend.copilot.pending_message_helpers import ( combine_pending_with_current, drain_pending_safe, - pending_texts_from, persist_pending_as_user_rows, persist_session_safe, ) @@ -70,6 +69,7 @@ from backend.copilot.service import ( inject_user_context, strip_user_context_tags, ) +from backend.copilot.session_cleanup import prune_orphan_tool_calls from backend.copilot.thinking_stripper import ThinkingStripper as _ThinkingStripper from backend.copilot.token_tracking import persist_and_record_usage from backend.copilot.tools import execute_tool, get_available_tools @@ -948,6 +948,12 @@ async def stream_chat_completion_baseline( f"Session {session_id} not found. Please create a new session first." ) + # Drop orphan tool_use + trailing stop-marker rows left by a previous + # Stop mid-tool-call so the new turn starts from a well-formed message list. + prune_orphan_tool_calls( + session.messages, log_prefix=f"[Baseline] [{session_id[:12]}]" + ) + # Strip any user-injected tags on every turn. # Only the server-injected prefix on the first message is trusted. if message: @@ -982,7 +988,6 @@ async def stream_chat_completion_baseline( len(drained_at_start_pending), session_id, ) - drained_at_start_content = pending_texts_from(drained_at_start_pending) # Chronological combine: pending typed BEFORE this /stream # request's arrival go ahead of ``message``; race-path follow-ups # typed AFTER (queued while /stream was still processing) go diff --git a/autogpt_platform/backend/backend/copilot/constants.py b/autogpt_platform/backend/backend/copilot/constants.py index 9a7388ab1b..986a641c7e 100644 --- a/autogpt_platform/backend/backend/copilot/constants.py +++ b/autogpt_platform/backend/backend/copilot/constants.py @@ -9,6 +9,11 @@ COPILOT_RETRYABLE_ERROR_PREFIX = ( ) COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message +# Canonical marker appended as an assistant ChatMessage when the SDK stream +# ends without a ResultMessage (user hit Stop). Checked by exact equality +# at turn start so the next turn's --resume transcript doesn't carry it. +STOPPED_BY_USER_MARKER = f"{COPILOT_SYSTEM_PREFIX} Execution stopped by user" + # Prefix for all synthetic IDs generated by CoPilot block execution. # Used to distinguish CoPilot-generated records from real graph execution records # in PendingHumanReview and other tables. diff --git a/autogpt_platform/backend/backend/copilot/executor/processor.py b/autogpt_platform/backend/backend/copilot/executor/processor.py index 8a25e1a1d9..f40264b70b 100644 --- a/autogpt_platform/backend/backend/copilot/executor/processor.py +++ b/autogpt_platform/backend/backend/copilot/executor/processor.py @@ -361,26 +361,34 @@ class CoPilotProcessor: permissions=entry.permissions, request_arrival_at=entry.request_arrival_at, ) - async for chunk in stream_registry.stream_and_publish( + published_stream = stream_registry.stream_and_publish( session_id=entry.session_id, turn_id=entry.turn_id, stream=raw_stream, - ): - if cancel.is_set(): - log.info("Cancel requested, breaking stream") - break + ) + # Explicit aclose() on early exit: ``async for … break`` does + # not close the generator, so GeneratorExit would never reach + # stream_chat_completion_sdk, leaving its stream lock held + # until GC eventually runs. + try: + async for chunk in published_stream: + if cancel.is_set(): + log.info("Cancel requested, breaking stream") + break - # Capture StreamError so mark_session_completed receives - # the error message (stream_and_publish yields but does - # not publish StreamError — that's done by mark_session_completed). - if isinstance(chunk, StreamError): - error_msg = chunk.errorText - break + # Capture StreamError so mark_session_completed receives + # the error message (stream_and_publish yields but does + # not publish StreamError — that's done by mark_session_completed). + if isinstance(chunk, StreamError): + error_msg = chunk.errorText + break - current_time = time.monotonic() - if current_time - last_refresh >= refresh_interval: - cluster_lock.refresh() - last_refresh = current_time + current_time = time.monotonic() + if current_time - last_refresh >= refresh_interval: + cluster_lock.refresh() + last_refresh = current_time + finally: + await published_stream.aclose() # Stream loop completed if cancel.is_set(): diff --git a/autogpt_platform/backend/backend/copilot/executor/processor_test.py b/autogpt_platform/backend/backend/copilot/executor/processor_test.py index f565c5a2b3..5541648747 100644 --- a/autogpt_platform/backend/backend/copilot/executor/processor_test.py +++ b/autogpt_platform/backend/backend/copilot/executor/processor_test.py @@ -10,14 +10,18 @@ the real production helpers from ``processor.py`` so the routing logic has meaningful coverage. """ -from unittest.mock import AsyncMock, patch +import logging +import threading +from unittest.mock import AsyncMock, MagicMock, patch import pytest from backend.copilot.executor.processor import ( + CoPilotProcessor, resolve_effective_mode, resolve_use_sdk_for_mode, ) +from backend.copilot.executor.utils import CoPilotExecutionEntry, CoPilotLogMetadata class TestResolveUseSdkForMode: @@ -173,3 +177,101 @@ class TestResolveEffectiveMode: ) as flag_mock: assert await resolve_effective_mode("fast", None) is None flag_mock.assert_awaited_once() + + +# --------------------------------------------------------------------------- +# _execute_async aclose propagation +# --------------------------------------------------------------------------- + + +class _TrackedStream: + """Minimal async-generator stand-in that records whether ``aclose`` + was called, so tests can verify the processor forces explicit cleanup + of the published stream on every exit path (normal + break on cancel).""" + + def __init__(self, events: list): + self._events = events + self.aclose_called = False + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._events: + raise StopAsyncIteration + return self._events.pop(0) + + async def aclose(self) -> None: + self.aclose_called = True + + +def _make_entry() -> CoPilotExecutionEntry: + return CoPilotExecutionEntry( + session_id="sess-1", + turn_id="turn-1", + user_id="user-1", + message="hi", + is_user_message=True, + request_arrival_at=0.0, + ) + + +def _make_log() -> CoPilotLogMetadata: + return CoPilotLogMetadata(logger=logging.getLogger("test-copilot")) + + +class TestExecuteAsyncAclose: + """``_execute_async`` must call ``aclose`` on the published stream both + when the loop exits naturally and when ``cancel`` is set mid-stream — + otherwise ``stream_chat_completion_sdk`` stays suspended and keeps + holding the per-session Redis lock until GC.""" + + def _patches(self, published_stream: _TrackedStream): + """Shared mock context: patches every dependency ``_execute_async`` + touches so the aclose path is the only behaviour under test.""" + return [ + patch( + "backend.copilot.executor.processor.ChatConfig", + return_value=MagicMock(test_mode=True, use_claude_agent_sdk=True), + ), + patch( + "backend.copilot.executor.processor.stream_chat_completion_dummy", + return_value=MagicMock(), + ), + patch( + "backend.copilot.executor.processor.stream_registry.stream_and_publish", + return_value=published_stream, + ), + patch( + "backend.copilot.executor.processor.stream_registry.mark_session_completed", + new=AsyncMock(), + ), + ] + + @pytest.mark.asyncio + async def test_normal_exit_calls_aclose(self) -> None: + published = _TrackedStream(events=[MagicMock(), MagicMock()]) + proc = CoPilotProcessor() + cancel = threading.Event() + cluster_lock = MagicMock() + + patches = self._patches(published) + with patches[0], patches[1], patches[2], patches[3]: + await proc._execute_async(_make_entry(), cancel, cluster_lock, _make_log()) + + assert published.aclose_called is True + + @pytest.mark.asyncio + async def test_cancel_break_calls_aclose(self) -> None: + events = [MagicMock()] # first chunk delivered, then cancel fires + published = _TrackedStream(events=events) + proc = CoPilotProcessor() + cancel = threading.Event() + cancel.set() # pre-set so the loop breaks on the first chunk + cluster_lock = MagicMock() + + patches = self._patches(published) + with patches[0], patches[1], patches[2], patches[3]: + await proc._execute_async(_make_entry(), cancel, cluster_lock, _make_log()) + + assert published.aclose_called is True diff --git a/autogpt_platform/backend/backend/copilot/sdk/response_adapter_test.py b/autogpt_platform/backend/backend/copilot/sdk/response_adapter_test.py index c93286a3d6..634454f9e5 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/response_adapter_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/response_adapter_test.py @@ -21,8 +21,6 @@ from backend.copilot.response_model import ( StreamFinishStep, StreamHeartbeat, StreamReasoningDelta, - StreamReasoningEnd, - StreamReasoningStart, StreamStart, StreamStartStep, StreamTextDelta, diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py index 8fea273b5d..ea0a135559 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service.py @@ -48,11 +48,12 @@ from ..config import ChatConfig, CopilotLlmModel, CopilotMode from ..constants import ( COPILOT_ERROR_PREFIX, COPILOT_RETRYABLE_ERROR_PREFIX, - COPILOT_SYSTEM_PREFIX, FRIENDLY_TRANSIENT_MSG, + STOPPED_BY_USER_MARKER, STREAM_IDLE_TIMEOUT_SECONDS, is_transient_api_error, ) +from ..session_cleanup import prune_orphan_tool_calls from ..context import encode_cwd_for_cli, get_workspace_manager from ..graphiti.config import is_enabled_for_user from ..model import ( @@ -70,7 +71,6 @@ from ..pending_message_helpers import ( persist_session_safe, ) from ..pending_messages import ( - PendingMessage, drain_pending_for_persist, push_pending_message, ) @@ -2504,10 +2504,7 @@ async def _run_stream_attempt( for r in closing_responses: yield r ctx.session.messages.append( - ChatMessage( - role="assistant", - content=f"{COPILOT_SYSTEM_PREFIX} Execution stopped by user", - ) + ChatMessage(role="assistant", content=STOPPED_BY_USER_MARKER) ) if ( @@ -2737,7 +2734,7 @@ async def stream_chat_completion_sdk( model: CopilotLlmModel | None = None, request_arrival_at: float = 0.0, **_kwargs: Any, -) -> AsyncIterator[StreamBaseResponse]: +) -> AsyncGenerator[StreamBaseResponse, None]: """Stream chat completion using Claude Agent SDK. Args: @@ -2781,6 +2778,10 @@ async def stream_chat_completion_sdk( ) session.messages.pop() + # Drop orphan tool_use + trailing stop-marker rows left by a previous + # Stop mid-tool-call so the next turn's --resume transcript is well-formed. + prune_orphan_tool_calls(session.messages, log_prefix=f"[SDK] [{session_id[:12]}]") + # Strip any user-injected tags on every turn. # Only the server-injected prefix on the first message is trusted. if message: @@ -3191,10 +3192,7 @@ async def stream_chat_completion_sdk( # Chronological combine: items typed BEFORE this request # arrived go ahead of ``current_message``; items typed AFTER # (race path, queued while /stream was still processing) go - # after. ``pending_texts`` is kept around because downstream - # code (the executor's update_message_content_by_sequence - # call) needs the pre-combine list. - pending_texts = pending_texts_from(pending_messages) + # after. current_message = combine_pending_with_current( pending_messages, current_message, diff --git a/autogpt_platform/backend/backend/copilot/sdk/session_persistence_test.py b/autogpt_platform/backend/backend/copilot/sdk/session_persistence_test.py index ea7b128927..d7cbc1d24e 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/session_persistence_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/session_persistence_test.py @@ -19,9 +19,11 @@ from __future__ import annotations from datetime import datetime, timezone from unittest.mock import MagicMock +from backend.copilot.constants import STOPPED_BY_USER_MARKER from backend.copilot.model import ChatMessage, ChatSession from backend.copilot.response_model import StreamStartStep, StreamTextDelta from backend.copilot.sdk.service import _dispatch_response, _StreamAccumulator +from backend.copilot.session_cleanup import prune_orphan_tool_calls _NOW = datetime(2024, 1, 1, tzinfo=timezone.utc) @@ -215,3 +217,183 @@ class TestPreCreateAssistantMessage: _simulate_pre_create(acc, ctx) assert len(ctx.session.messages) == 0 + + +class TestPruneOrphanToolCalls: + """A Stop mid-tool-call leaves the session ending on an assistant row whose + ``tool_calls`` have no matching ``role="tool"`` row. Unless pruned before + the next turn, the ``--resume`` transcript would hand Claude CLI a + ``tool_use`` without a paired ``tool_result`` and the SDK would fail. + """ + + @staticmethod + def _tool_call(call_id: str, name: str = "bash_exec") -> dict: + return { + "id": call_id, + "type": "function", + "function": {"name": name, "arguments": "{}"}, + } + + def test_stop_mid_tool_leaves_orphan_assistant(self) -> None: + """Stop between StreamToolInputAvailable and StreamToolOutputAvailable: + the assistant row has ``tool_calls`` but no matching tool row.""" + messages: list[ChatMessage] = [ + ChatMessage(role="user", content="do something"), + ChatMessage( + role="assistant", + content="", + tool_calls=[self._tool_call("tc_abc")], + ), + ] + + removed = prune_orphan_tool_calls(messages) + + assert removed == 1 + assert len(messages) == 1 + assert messages[-1].role == "user" + + def test_stop_strips_stopped_by_user_marker_and_orphan(self) -> None: + """The service also appends a ``STOPPED_BY_USER_MARKER`` after a + user stop when the stream loop exits cleanly; both tail rows must go.""" + messages: list[ChatMessage] = [ + ChatMessage(role="user", content="do something"), + ChatMessage( + role="assistant", + content="", + tool_calls=[self._tool_call("tc_abc")], + ), + ChatMessage(role="assistant", content=STOPPED_BY_USER_MARKER), + ] + + removed = prune_orphan_tool_calls(messages) + + assert removed == 2 + assert len(messages) == 1 + assert messages[-1].role == "user" + + def test_completed_tool_call_is_preserved(self) -> None: + """An assistant row whose tool_calls are all resolved is a healthy + trailing state and must not be popped.""" + messages: list[ChatMessage] = [ + ChatMessage(role="user", content="do something"), + ChatMessage( + role="assistant", + content="", + tool_calls=[self._tool_call("tc_abc")], + ), + ChatMessage( + role="tool", + content="ok", + tool_call_id="tc_abc", + ), + ] + + removed = prune_orphan_tool_calls(messages) + + assert removed == 0 + assert len(messages) == 3 + + def test_partial_resolution_still_pops(self) -> None: + """If an assistant emits multiple tool_calls and only some are + resolved, the assistant row is still unsafe for ``--resume``.""" + messages: list[ChatMessage] = [ + ChatMessage(role="user", content="do something"), + ChatMessage( + role="assistant", + content="", + tool_calls=[ + self._tool_call("tc_1"), + self._tool_call("tc_2"), + ], + ), + ChatMessage( + role="tool", + content="ok", + tool_call_id="tc_1", + ), + ] + + removed = prune_orphan_tool_calls(messages) + + # Both the orphan assistant and its partial tool row are dropped. + assert removed == 2 + assert len(messages) == 1 + assert messages[-1].role == "user" + + def test_plain_assistant_text_preserved(self) -> None: + """A regular text-only assistant tail is healthy and must be kept.""" + messages: list[ChatMessage] = [ + ChatMessage(role="user", content="hi"), + ChatMessage(role="assistant", content="hello"), + ] + + removed = prune_orphan_tool_calls(messages) + + assert removed == 0 + assert len(messages) == 2 + + def test_empty_session_is_noop(self) -> None: + messages: list[ChatMessage] = [] + assert prune_orphan_tool_calls(messages) == 0 + + +class TestPruneOrphanToolCallsLogging: + """``prune_orphan_tool_calls`` emits an INFO log when the caller passes + ``log_prefix`` and something was actually popped. Shared by the SDK + and baseline turn-start cleanup so both paths log in the same shape.""" + + def _tool_call(self, call_id: str) -> dict: + return {"id": call_id, "type": "function", "function": {"name": "bash"}} + + def test_logs_when_something_was_pruned(self, caplog) -> None: + import backend.copilot.session_cleanup as sc + + messages: list[ChatMessage] = [ + ChatMessage(role="user", content="hi"), + ChatMessage( + role="assistant", content="", tool_calls=[self._tool_call("tc_1")] + ), + ] + + sc.logger.propagate = True + caplog.set_level("INFO", logger=sc.logger.name) + removed = prune_orphan_tool_calls(messages, log_prefix="[TEST] [abc123]") + + assert removed == 1 + assert any( + "[TEST] [abc123]" in r.message and "Dropped 1" in r.message + for r in caplog.records + ), caplog.text + + def test_no_log_when_nothing_to_prune(self, caplog) -> None: + import backend.copilot.session_cleanup as sc + + messages: list[ChatMessage] = [ + ChatMessage(role="user", content="hi"), + ChatMessage(role="assistant", content="hello"), + ] + + sc.logger.propagate = True + caplog.set_level("INFO", logger=sc.logger.name) + removed = prune_orphan_tool_calls(messages, log_prefix="[TEST] [xyz]") + + assert removed == 0 + assert not any("[TEST] [xyz]" in r.message for r in caplog.records), caplog.text + + def test_no_log_when_log_prefix_is_none(self, caplog) -> None: + """Without ``log_prefix``, ``prune_orphan_tool_calls`` is silent.""" + import backend.copilot.session_cleanup as sc + + messages: list[ChatMessage] = [ + ChatMessage(role="user", content="hi"), + ChatMessage( + role="assistant", content="", tool_calls=[self._tool_call("tc_1")] + ), + ] + + sc.logger.propagate = True + caplog.set_level("INFO", logger=sc.logger.name) + removed = prune_orphan_tool_calls(messages) + + assert removed == 1 + assert caplog.text == "" diff --git a/autogpt_platform/backend/backend/copilot/session_cleanup.py b/autogpt_platform/backend/backend/copilot/session_cleanup.py new file mode 100644 index 0000000000..b23056ca68 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/session_cleanup.py @@ -0,0 +1,77 @@ +"""Pre-turn cleanup of transient markers left on ``session.messages`` by +prior turns (user-initiated Stop, cancelled tool calls, etc.). + +Shared by both the SDK and baseline chat entry points so both code paths +start every new turn from a well-formed message list. +""" + +import logging + +from backend.copilot.constants import STOPPED_BY_USER_MARKER +from backend.copilot.model import ChatMessage + +logger = logging.getLogger(__name__) + + +def prune_orphan_tool_calls( + messages: list[ChatMessage], + log_prefix: str | None = None, +) -> int: + """Pop trailing orphan tool-use blocks from *messages* in place. + + A Stop mid-tool-call leaves the session ending on an assistant message + whose ``tool_calls`` have no matching ``role="tool"`` row — the tool + never produced output because the executor was cancelled. Feeding that + tail to the next ``--resume`` turn would hand the Claude CLI a + ``tool_use`` with no paired ``tool_result`` and the SDK raises a + generic error. + + Also strips trailing ``STOPPED_BY_USER_MARKER`` assistant rows emitted + by the same Stop path so the next turn's transcript starts clean. + + If *log_prefix* is given, emits an INFO log with the prefix whenever + something was actually popped so the turn-start cleanup is visible. + + In-memory only — the DB write path is append-only via + ``start_sequence`` so no delete is needed; the same rows are popped + again on the next session load. + """ + cut_index: int | None = None + resolved_ids: set[str] = set() + + for i in range(len(messages) - 1, -1, -1): + msg = messages[i] + + if msg.role == "tool" and msg.tool_call_id: + resolved_ids.add(msg.tool_call_id) + continue + + if msg.role == "assistant" and msg.content == STOPPED_BY_USER_MARKER: + cut_index = i + continue + + if msg.role == "assistant" and msg.tool_calls: + pending_ids = { + tc.get("id") + for tc in msg.tool_calls + if isinstance(tc, dict) and tc.get("id") + } + if pending_ids and not pending_ids.issubset(resolved_ids): + cut_index = i + break + + break + + if cut_index is None: + return 0 + + removed = len(messages) - cut_index + del messages[cut_index:] + if log_prefix: + logger.info( + "%s Dropped %d trailing orphan tool-use/stop row(s) " + "before starting new turn", + log_prefix, + removed, + ) + return removed diff --git a/autogpt_platform/backend/backend/copilot/stream_registry.py b/autogpt_platform/backend/backend/copilot/stream_registry.py index 111fbef90a..f4a26b7008 100644 --- a/autogpt_platform/backend/backend/copilot/stream_registry.py +++ b/autogpt_platform/backend/backend/copilot/stream_registry.py @@ -17,7 +17,7 @@ Subscribers: import asyncio import logging import time -from collections.abc import AsyncIterator +from collections.abc import AsyncGenerator from dataclasses import dataclass, field from datetime import datetime, timezone from typing import Any, Literal @@ -329,8 +329,8 @@ async def publish_chunk( async def stream_and_publish( session_id: str, turn_id: str, - stream: AsyncIterator[StreamBaseResponse], -) -> AsyncIterator[StreamBaseResponse]: + stream: AsyncGenerator[StreamBaseResponse, None], +) -> AsyncGenerator[StreamBaseResponse, None]: """Wrap an async stream iterator with registry publishing. Publishes each chunk to the stream registry for frontend SSE consumption, @@ -353,27 +353,35 @@ async def stream_and_publish( """ publish_failed_once = False - async for event in stream: - if turn_id and not isinstance(event, (StreamFinish, StreamError)): - try: - await publish_chunk(turn_id, event, session_id=session_id) - except (RedisError, ConnectionError, OSError): - if not publish_failed_once: - publish_failed_once = True - logger.warning( - "[stream_and_publish] Failed to publish chunk %s for %s " - "(further failures logged at DEBUG)", - type(event).__name__, - session_id[:12], - exc_info=True, - ) - else: - logger.debug( - "[stream_and_publish] Failed to publish chunk %s", - type(event).__name__, - exc_info=True, - ) - yield event + # async-for does not close an iterator on GeneratorExit; forward close + # to ``stream`` explicitly so its own cleanup (stream lock, persist) + # runs deterministically instead of waiting for GC. + try: + async for event in stream: + if turn_id and not isinstance(event, (StreamFinish, StreamError)): + try: + await publish_chunk(turn_id, event, session_id=session_id) + except (RedisError, ConnectionError, OSError): + # Full stack trace on the first failure; terser lines + # for the rest so subsequent failures don't flood logs + # while still being visible at WARNING. + if not publish_failed_once: + publish_failed_once = True + logger.warning( + "[stream_and_publish] Failed to publish chunk %s for %s", + type(event).__name__, + session_id[:12], + exc_info=True, + ) + else: + logger.warning( + "[stream_and_publish] Failed to publish chunk %s for %s", + type(event).__name__, + session_id[:12], + ) + yield event + finally: + await stream.aclose() async def subscribe_to_session( diff --git a/autogpt_platform/backend/backend/copilot/stream_registry_test.py b/autogpt_platform/backend/backend/copilot/stream_registry_test.py index a09940a4a8..28ec199025 100644 --- a/autogpt_platform/backend/backend/copilot/stream_registry_test.py +++ b/autogpt_platform/backend/backend/copilot/stream_registry_test.py @@ -108,3 +108,116 @@ async def test_disconnect_all_listeners_timeout_not_counted(): await task except asyncio.CancelledError: pass + + +# --------------------------------------------------------------------------- +# stream_and_publish: closing the wrapper forwards GeneratorExit into the +# inner stream so its finally (stream lock release, etc.) runs deterministically. +# --------------------------------------------------------------------------- + + +class _FakeEvent: + """Minimal stand-in for a StreamBaseResponse so publish_chunk is a no-op.""" + + def __init__(self, idx: int): + self.idx = idx + + +@pytest.mark.asyncio +async def test_stream_and_publish_aclose_propagates_to_inner_stream(): + """Closing the wrapper MUST run the inner generator's finally block.""" + inner_finally_ran = asyncio.Event() + + async def _inner(): + try: + yield _FakeEvent(0) + yield _FakeEvent(1) + yield _FakeEvent(2) + finally: + inner_finally_ran.set() + + inner = _inner() + # Empty turn_id skips publish_chunk — keeps the test hermetic (no Redis). + wrapper = stream_registry.stream_and_publish( + session_id="sess-test", turn_id="", stream=inner + ) + + # Consume one event, then close the wrapper early. + first = await wrapper.__anext__() + assert isinstance(first, _FakeEvent) + + await wrapper.aclose() + + # The inner generator's finally must have run deterministically + # (not deferred to GC) so the caller's cleanup (lock release, etc.) + # is observable right after aclose returns. + assert inner_finally_ran.is_set() + + +@pytest.mark.asyncio +async def test_stream_and_publish_logs_warning_on_publish_chunk_failure(): + """``stream_and_publish`` must not propagate a Redis publish failure — + it warns once with full stack trace, keeps yielding, and logs + subsequent failures at WARNING (terser, no exc_info) so repeated + errors stay visible without flooding the trace.""" + from redis.exceptions import RedisError + + async def _inner(): + yield _FakeEvent(0) + yield _FakeEvent(1) + yield _FakeEvent(2) + + async def _raising_publish(turn_id, event, session_id=None): + raise RedisError("boom") + + warning_mock = patch.object( + stream_registry.logger, "warning", autospec=True + ).start() + try: + with patch.object(stream_registry, "publish_chunk", new=_raising_publish): + wrapper = stream_registry.stream_and_publish( + session_id="sess-test", turn_id="turn-1", stream=_inner() + ) + received = [evt async for evt in wrapper] + finally: + patch.stopall() + + # Every event still yields through — publish failures don't break the stream. + assert len(received) == 3 + # One warning per failed publish (3 total). First call carries a + # stack trace (``exc_info=True``); subsequent calls are terser. + assert warning_mock.call_count == 3 + assert warning_mock.call_args_list[0].kwargs.get("exc_info") is True + assert warning_mock.call_args_list[1].kwargs.get("exc_info") is not True + + +@pytest.mark.asyncio +async def test_stream_and_publish_consumer_break_then_aclose_releases_inner(): + """The processor pattern — break on cancel, then aclose — must release.""" + inner_finally_ran = asyncio.Event() + + async def _inner(): + try: + for idx in range(100): + yield _FakeEvent(idx) + finally: + inner_finally_ran.set() + + inner = _inner() + wrapper = stream_registry.stream_and_publish( + session_id="sess-test", turn_id="", stream=inner + ) + + # Mimic the processor: consume a few events, simulate Stop by breaking, + # then aclose the wrapper (as processor._execute_async now does in the + # try/finally around the async for). + try: + count = 0 + async for _ in wrapper: + count += 1 + if count >= 2: + break + finally: + await wrapper.aclose() + + assert inner_finally_ran.is_set() diff --git a/autogpt_platform/backend/backend/copilot/tools/bash_exec.py b/autogpt_platform/backend/backend/copilot/tools/bash_exec.py index ee87386cdb..1fbf4adc9c 100644 --- a/autogpt_platform/backend/backend/copilot/tools/bash_exec.py +++ b/autogpt_platform/backend/backend/copilot/tools/bash_exec.py @@ -47,7 +47,7 @@ class BashExecTool(BaseTool): return ( "Execute a Bash command or script. Shares filesystem with SDK file tools. " "Useful for scripts, data processing, and package installation. " - "Killed after timeout (default 30s, max 120s)." + "Killed after `timeout` seconds." ) @property @@ -61,8 +61,8 @@ class BashExecTool(BaseTool): }, "timeout": { "type": "integer", - "description": "Max seconds (default 30, max 120).", - "default": 30, + "description": "Timeout in seconds; raise for long-running commands.", + "default": 120, }, }, "required": ["command"], @@ -80,7 +80,7 @@ class BashExecTool(BaseTool): user_id: str | None, session: ChatSession, command: str = "", - timeout: int = 30, + timeout: int = 120, **kwargs: Any, ) -> ToolResponseBase: """Run a bash command on E2B (if available) or in a bubblewrap sandbox. @@ -129,7 +129,7 @@ class BashExecTool(BaseTool): message=( "Execution timed out" if timed_out - else f"Command executed (exit {exit_code})" + else f"Command executed with status code {exit_code}" ), stdout=stdout, stderr=stderr, @@ -183,7 +183,7 @@ class BashExecTool(BaseTool): stdout = stdout.replace(secret, "[REDACTED]") stderr = stderr.replace(secret, "[REDACTED]") return BashExecResponse( - message=f"Command executed on E2B (exit {result.exit_code})", + message=f"Command executed with status code {result.exit_code}", stdout=stdout, stderr=stderr, exit_code=result.exit_code, diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/GenericTool.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/GenericTool.tsx index c897da9bdb..995c18df05 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/GenericTool.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/GenericTool.tsx @@ -236,9 +236,39 @@ function getBashAccordionData( ? `Command failed (exit ${exitCode})` : "Command output"; + // The command itself is already in the subtitle row above; surface the + // outcome here so scanning the closed accordion tells the reader "how it + // ended" at a glance. Prefer the backend's own first line of output + // (stderr for failures/timeouts — that's where bash_exec writes + // "Timed out after Xs" and where shells emit "command not found" etc., + // stdout for success) over a terse "exit N" so the reader actually sees + // WHY the command ended. + const firstNonEmptyLine = (s: string | null): string | null => { + if (!s) return null; + const line = s.split("\n").find((l) => l.trim().length > 0); + return line ? truncate(line.trim(), 80) : null; + }; + const stderrPreview = firstNonEmptyLine(stderr); + const stdoutPreview = firstNonEmptyLine(stdout); + let description: string | undefined; + if (timedOut) { + description = stderrPreview ?? "timed out"; + } else if (exitCode !== null && exitCode !== 0) { + description = stderrPreview + ? `status code ${exitCode} · ${stderrPreview}` + : `status code ${exitCode}`; + } else if (exitCode === 0) { + description = stdoutPreview ?? "completed"; + } else { + // Historical sessions persisted before exit_code/timed_out were added + // fall through here — fall back to the command preview so the closed + // accordion still tells the reader what ran. + description = truncate(command, 80); + } + return { title, - description: truncate(command, 80), + description, content: (
{command && ( @@ -703,7 +733,6 @@ export function GenericTool({ part }: Props) { return (
- {/* Status line: always visible so the user sees what tool ran */}
= {}): ToolUIPart { + return { + type: "tool-bash_exec", + toolCallId: "call-1", + state: "input-streaming", + input: { command: 'echo "hi"' }, + ...overrides, + } as ToolUIPart; +} + +describe("GenericTool", () => { + it("shows a subtitle and no accordion while the tool is streaming", () => { + const { container } = render( + , + ); + expect(screen.queryByRole("button")).toBeNull(); + expect(container.textContent).toContain("Running"); + }); + + it("renders exactly one row once output is available (accordion only, no loose status line)", () => { + render( + , + ); + // The accordion trigger is the only interactive element; no separate + // MorphingTextAnimation status row is rendered alongside it. + const triggers = screen.getAllByRole("button"); + expect(triggers.length).toBe(1); + expect(triggers[0].textContent).toContain("Command failed (exit 1)"); + }); + + it("shows 'status code N · ' on non-zero exit", () => { + render( + , + ); + const trigger = screen.getByRole("button", { expanded: false }); + expect(trigger.textContent).toContain("Command failed (exit 127)"); + expect(trigger.textContent).toContain( + "status code 127 · bash: missing-bin: command not found", + ); + }); + + it("falls back to bare 'status code N' when stderr is empty", () => { + render( + , + ); + const trigger = screen.getByRole("button", { expanded: false }); + expect(trigger.textContent).toContain("status code 2"); + expect(trigger.textContent).not.toContain("·"); + }); + + it("shows the stderr first line for a timed-out command", () => { + render( + , + ); + const trigger = screen.getByRole("button", { expanded: false }); + expect(trigger.textContent).toContain("Command timed out"); + expect(trigger.textContent).toContain("Timed out after 120s"); + expect(trigger.textContent).not.toContain("sleep 120"); + }); + + it("falls back to the command preview for legacy outputs missing exit_code/timed_out", () => { + render( + , + ); + const trigger = screen.getByRole("button", { expanded: false }); + expect(trigger.textContent).toContain("echo hello"); + }); + + it("prefers stdout first line on exit 0, falls back to 'completed'", () => { + const { rerender } = render( + , + ); + const trigger1 = screen.getByRole("button", { expanded: false }); + expect(trigger1.textContent).toContain("Hello, world!"); + expect(trigger1.textContent).not.toContain("more lines below"); + + rerender( + , + ); + const trigger2 = screen.getByRole("button", { expanded: false }); + expect(trigger2.textContent).toContain("completed"); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/__tests__/helpers.test.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/__tests__/helpers.test.ts index cc8bcc8afb..de0b9155b6 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/__tests__/helpers.test.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/__tests__/helpers.test.ts @@ -202,14 +202,14 @@ describe("getAnimationText", () => { expect(getAnimationText(part, "bash")).toBe("Ran: echo hello"); }); - it("shows exit code on non-zero exit", () => { + it("still shows the command even on non-zero exit (exit code lives in the accordion description)", () => { const part = makePart({ type: "tool-bash_exec", state: "output-available", input: { command: "false" }, output: { exit_code: 1 }, }); - expect(getAnimationText(part, "bash")).toBe("Command exited with code 1"); + expect(getAnimationText(part, "bash")).toBe("Ran: false"); }); it("shows error text for bash failure", () => { diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/helpers.ts index f0a1cd6853..f8da6fbc2f 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/helpers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/helpers.ts @@ -199,17 +199,6 @@ export function humanizeFileName(filePath: string): string { return `"${words.join(" ")}"`; } -/* ------------------------------------------------------------------ */ -/* Exit code helper */ -/* ------------------------------------------------------------------ */ - -function getExitCode(output: unknown): number | null { - if (!output || typeof output !== "object") return null; - const parsed = output as Record; - if (typeof parsed.exit_code === "number") return parsed.exit_code; - return null; -} - /* ------------------------------------------------------------------ */ /* Animation text */ /* ------------------------------------------------------------------ */ @@ -287,13 +276,11 @@ export function getAnimationText( } case "output-available": { switch (category) { - case "bash": { - const exitCode = getExitCode(part.output); - if (exitCode !== null && exitCode !== 0) { - return `Command exited with code ${exitCode}`; - } + case "bash": + // Subtitle always shows WHAT ran. The accordion title + description + // carry HOW it ended (exit code / "timed out"), so repeating the + // exit status here would just double up. return shortSummary ? `Ran: ${shortSummary}` : "Command completed"; - } case "web": if (toolName === "WebSearch") { return shortSummary From 343222ace1568fdb25ef2bc6a3106baea1e3d7a5 Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Tue, 21 Apr 2026 14:01:09 +0700 Subject: [PATCH 03/41] feat(platform): defer paid-to-paid subscription downgrades + cancel-pending flow (#12865) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Why / What / How **Why:** Only downgrades to FREE were scheduled at period end; paid→paid downgrades (e.g. BUSINESS→PRO) applied immediately via Stripe proration. The asymmetry meant users lost their higher tier mid-cycle in exchange for a Stripe credit voucher only redeemable on a future subscription — a confusing pattern that produces negative-value paths for users actually cancelling. There was also no way to cancel a pending downgrade or paid→FREE cancellation once scheduled. **What:** Standardize on "upgrade = immediate, downgrade = next cycle" and let users cancel a pending change by clicking their current tier. Harden the new code against conflicting subscription state, concurrent tab races, flaky Stripe calls, and hot-path latency regressions. **How:** Subscription state machine: - **Upgrade** (PRO→BUSINESS) — `stripe.Subscription.modify` with immediate proration (unchanged). If a downgrade schedule is already attached, release it first so the upgrade wins. - **Paid→paid downgrade** (BUSINESS→PRO) — creates a `stripe.SubscriptionSchedule` with two phases (current tier until `current_period_end`, target tier after). No mid-cycle tier demotion. Defensive pre-clear: existing schedule → release; `cancel_at_period_end=True` → set to False. - **Paid→FREE** — unchanged: `cancel_at_period_end=True`. - **Same-tier update** — reuses the existing `POST /credits/subscription` route. When `target_tier == current_tier`, backend calls `release_pending_subscription_schedule` (idempotent) and returns status. No dedicated cancel-pending endpoint — "Keep my current tier" IS the cancel operation. - `release_pending_subscription_schedule` is idempotent on terminal-state schedules and clears both `schedule` and `cancel_at_period_end` atomically per call. API surface: - New fields on `SubscriptionStatusResponse`: `pending_tier` + `pending_tier_effective_at` (pulled from the schedule's next-phase `start_date` so dashboard-authored schedules report the correct timestamp). - `POST /credits/subscription` now returns `SubscriptionStatusResponse` (previously `SubscriptionCheckoutResponse`); the response still carries `url` for checkout flows and adds the status fields inline. - `get_pending_subscription_change` is cached with a 30s TTL — avoids hammering Stripe on every home-page load. - Webhook dispatches `subscription_schedule.{released,completed,updated}` through the main `sync_subscription_from_stripe` flow so both event sources converge to the same DB state. Implementation notes: - New Stripe calls use native async (`stripe.Subscription.list_async` etc.) and typed attribute access — no `run_in_threadpool` wrapping in the new helpers. - Shared `_get_active_subscription` helper collapses the "list active/trialing subs, take first" pattern used by 4 callers. Frontend: - `PendingChangeBanner` sub-component above the tier grid with formatted effective date + "Keep [CurrentTier]" button. `aria-live="polite"` for screen readers; locale pinned to `en-US` to avoid SSR/CSR hydration mismatch. - "Keep [CurrentTier]" also available as a button on the current tier card. - Other tier buttons disabled while a change is pending — user must resolve pending first to prevent stacked schedules. - `cancelPendingChange` reuses `useUpdateSubscriptionTier` with `tier: current_tier`; awaits `refetch()` on both success and error paths so the UI reconciles even if the server succeeded but the client didn't receive the response. ### Changes **Backend (`credit.py`, `v1.py`)** - Tier-ordering helpers (`is_tier_upgrade`/`is_tier_downgrade`). - `modify_stripe_subscription_for_tier` routes downgrades through `_schedule_downgrade_at_period_end`; upgrade path releases any pending schedule first. - `_schedule_downgrade_at_period_end` defensively releases pre-existing schedules and clears `cancel_at_period_end` before creating the new schedule. - `release_pending_subscription_schedule` idempotent on terminal-state schedules; logs partial-failure outcomes. - `_next_phase_tier_and_start` returns both tier and phase-start timestamp; warns on unknown prices. - `get_pending_subscription_change` cached (30s TTL), narrow exception handling. - `sync_subscription_schedule_from_stripe` delegates to `sync_subscription_from_stripe` for convergence with the main webhook path. - Shared `_get_active_subscription` + `_release_schedule_ignoring_terminal` helpers. - `POST /credits/subscription` absorbs the same-tier "cancel pending change" branch. **Frontend (`SubscriptionTierSection/*`)** - `PendingChangeBanner` new sub-component (a11y, locale-pinned date, paid→FREE vs paid→paid copy split, non-null effective-date assertion, no `dark:` utilities). - "Keep [CurrentTier]" button on current tier card. - `useSubscriptionTierSection` — `cancelPendingChange` reuses the update-tier mutation. - Copy: downgrade dialog + status hint updated. - `helpers.ts` extracted from the main component. **Tests** - Backend: +24 tests (95/95 passing): upgrade-releases-pending-schedule, schedule-releases-existing-schedule, cancel-at-period-end collision, terminal-state release idempotency, unknown-price logging, status response population, same-tier-POST-with-pending, webhook delegation. - Frontend: +5 integration tests (21/21 passing): banner render/hide, Keep-button click from banner + current card, paid→paid dialog copy. ### Checklist - [x] Backend unit tests: 95 pass - [x] Frontend integration tests: 21 pass - [x] `poetry run format` / `poetry run lint` clean - [x] `pnpm format` / `pnpm lint` / `pnpm types` clean - [ ] Manual E2E on live Stripe (dev env) — pending deploy: BUSINESS→PRO creates schedule, DB tier unchanged until period end - [ ] Manual E2E: "Keep BUSINESS" in banner releases schedule - [ ] Manual E2E: cancel pending paid→FREE flips `cancel_at_period_end` back to false - [ ] Manual E2E: BUSINESS→PRO (scheduled) then attempt BUSINESS→FREE clears the PRO schedule, sets cancel_at_period_end - [ ] Manual E2E: BUSINESS→PRO (scheduled) then upgrade back to BUSINESS releases the schedule --- .../api/features/subscription_routes_test.py | 339 ++++- .../backend/backend/api/features/v1.py | 107 +- .../backend/backend/copilot/rate_limit.py | 124 +- .../backend/copilot/rate_limit_test.py | 74 + .../backend/backend/data/credit.py | 558 +++++++- .../backend/data/credit_subscription_test.py | 1274 ++++++++++++++++- .../SubscriptionTierSection.tsx | 154 +- .../SubscriptionTierSection.test.tsx | 235 ++- .../PendingChangeBanner.tsx | 60 + .../SubscriptionTierSection/helpers.ts | 54 + .../useSubscriptionTierSection.ts | 42 + .../frontend/src/app/api/openapi.json | 30 +- 12 files changed, 2907 insertions(+), 144 deletions(-) create mode 100644 autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/components/PendingChangeBanner/PendingChangeBanner.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/helpers.ts diff --git a/autogpt_platform/backend/backend/api/features/subscription_routes_test.py b/autogpt_platform/backend/backend/api/features/subscription_routes_test.py index c20e0d0ceb..96fd8763eb 100644 --- a/autogpt_platform/backend/backend/api/features/subscription_routes_test.py +++ b/autogpt_platform/backend/backend/api/features/subscription_routes_test.py @@ -47,6 +47,40 @@ def _configure_frontend_origin(mocker: pytest_mock.MockFixture) -> None: ) +@pytest.fixture(autouse=True) +def _stub_pending_subscription_change(mocker: pytest_mock.MockFixture) -> None: + """Default pending-change lookup to None so tests don't hit Stripe/DB. + + Individual tests can override via their own mocker.patch call. + """ + mocker.patch( + "backend.api.features.v1.get_pending_subscription_change", + new_callable=AsyncMock, + return_value=None, + ) + + +@pytest.fixture(autouse=True) +def _stub_subscription_status_lookups(mocker: pytest_mock.MockFixture) -> None: + """Stub Stripe price + proration lookups used by get_subscription_status. + + The POST /credits/subscription handler now returns the full subscription + status payload from every branch (same-tier, FREE downgrade, paid→paid + modify, checkout creation), so every POST test implicitly hits these + helpers. Individual tests can override via their own mocker.patch call. + """ + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + new_callable=AsyncMock, + return_value=None, + ) + mocker.patch( + "backend.api.features.v1.get_proration_credit_cents", + new_callable=AsyncMock, + return_value=0, + ) + + @pytest.mark.parametrize( "url,expected", [ @@ -407,30 +441,77 @@ def test_update_subscription_tier_enterprise_blocked( set_tier_mock.assert_not_awaited() -def test_update_subscription_tier_same_tier_is_noop( +def test_update_subscription_tier_same_tier_releases_pending_change( client: fastapi.testclient.TestClient, mocker: pytest_mock.MockFixture, ) -> None: - """POST /credits/subscription for the user's current paid tier returns 200 with empty URL. + """POST /credits/subscription for the user's current tier releases any pending change. - Without this guard a duplicate POST (double-click, browser retry, stale page) would - create a second Stripe Checkout Session for the same price, potentially billing the - user twice until the webhook reconciliation fires. + "Stay on my current tier" — the collapsed replacement for the old + /credits/subscription/cancel-pending route. Always calls + release_pending_subscription_schedule (idempotent when nothing is pending) + and returns the refreshed status with url="". Never creates a Checkout + Session — that would double-charge a user who double-clicks their own tier. """ mock_user = Mock() - mock_user.subscription_tier = SubscriptionTier.PRO - - async def mock_feature_enabled(*args, **kwargs): - return True + mock_user.subscription_tier = SubscriptionTier.BUSINESS mocker.patch( "backend.api.features.v1.get_user_by_id", new_callable=AsyncMock, return_value=mock_user, ) - mocker.patch( + release_mock = mocker.patch( + "backend.api.features.v1.release_pending_subscription_schedule", + new_callable=AsyncMock, + return_value=True, + ) + checkout_mock = mocker.patch( + "backend.api.features.v1.create_subscription_checkout", + new_callable=AsyncMock, + ) + feature_mock = mocker.patch( "backend.api.features.v1.is_feature_enabled", - side_effect=mock_feature_enabled, + new_callable=AsyncMock, + return_value=True, + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "BUSINESS", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["tier"] == "BUSINESS" + assert data["url"] == "" + release_mock.assert_awaited_once_with(TEST_USER_ID) + checkout_mock.assert_not_awaited() + # Same-tier branch short-circuits before the payment-flag check. + feature_mock.assert_not_awaited() + + +def test_update_subscription_tier_same_tier_no_pending_change_returns_status( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """Same-tier request when nothing is pending still returns status with url="".""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + release_mock = mocker.patch( + "backend.api.features.v1.release_pending_subscription_schedule", + new_callable=AsyncMock, + return_value=False, ) checkout_mock = mocker.patch( "backend.api.features.v1.create_subscription_checkout", @@ -447,10 +528,50 @@ def test_update_subscription_tier_same_tier_is_noop( ) assert response.status_code == 200 - assert response.json()["url"] == "" + data = response.json() + assert data["tier"] == "PRO" + assert data["url"] == "" + assert data["pending_tier"] is None + release_mock.assert_awaited_once_with(TEST_USER_ID) checkout_mock.assert_not_awaited() +def test_update_subscription_tier_same_tier_stripe_error_returns_502( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """Same-tier request surfaces a 502 when Stripe release fails. + + Carries forward the error contract from the removed + /credits/subscription/cancel-pending route so clients keep seeing 502 for + transient Stripe failures. + """ + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.BUSINESS + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.release_pending_subscription_schedule", + side_effect=stripe.StripeError("network"), + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "BUSINESS", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 502 + assert "contact support" in response.json()["detail"].lower() + + def test_update_subscription_tier_free_with_payment_schedules_cancel_and_does_not_update_db( client: fastapi.testclient.TestClient, mocker: pytest_mock.MockFixture, @@ -803,3 +924,197 @@ def test_update_subscription_tier_free_no_stripe_subscription( cancel_mock.assert_awaited_once_with(TEST_USER_ID) # DB tier must be updated immediately — no webhook will fire for a missing sub set_tier_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.FREE) + + +def test_get_subscription_status_includes_pending_tier( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """GET /credits/subscription exposes pending_tier and pending_tier_effective_at.""" + import datetime as dt + + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.BUSINESS + + effective_at = dt.datetime(2030, 1, 1, tzinfo=dt.timezone.utc) + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + return None + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + side_effect=mock_price_id, + ) + mocker.patch( + "backend.api.features.v1.get_proration_credit_cents", + new_callable=AsyncMock, + return_value=0, + ) + mocker.patch( + "backend.api.features.v1.get_pending_subscription_change", + new_callable=AsyncMock, + return_value=(SubscriptionTier.PRO, effective_at), + ) + + response = client.get("/credits/subscription") + + assert response.status_code == 200 + data = response.json() + assert data["pending_tier"] == "PRO" + assert data["pending_tier_effective_at"] is not None + + +def test_get_subscription_status_no_pending_tier( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """When no pending change exists the response omits pending_tier.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.PRO + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.get_subscription_price_id", + new_callable=AsyncMock, + return_value=None, + ) + mocker.patch( + "backend.api.features.v1.get_proration_credit_cents", + new_callable=AsyncMock, + return_value=0, + ) + mocker.patch( + "backend.api.features.v1.get_pending_subscription_change", + new_callable=AsyncMock, + return_value=None, + ) + + response = client.get("/credits/subscription") + + assert response.status_code == 200 + data = response.json() + assert data["pending_tier"] is None + assert data["pending_tier_effective_at"] is None + + +def test_update_subscription_tier_downgrade_paid_to_paid_schedules( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """A BUSINESS→PRO downgrade request dispatches to modify_stripe_subscription_for_tier.""" + mock_user = Mock() + mock_user.subscription_tier = SubscriptionTier.BUSINESS + + mocker.patch( + "backend.api.features.v1.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ) + mocker.patch( + "backend.api.features.v1.is_feature_enabled", + new_callable=AsyncMock, + return_value=True, + ) + modify_mock = mocker.patch( + "backend.api.features.v1.modify_stripe_subscription_for_tier", + new_callable=AsyncMock, + return_value=True, + ) + checkout_mock = mocker.patch( + "backend.api.features.v1.create_subscription_checkout", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/subscription", + json={ + "tier": "PRO", + "success_url": f"{TEST_FRONTEND_ORIGIN}/success", + "cancel_url": f"{TEST_FRONTEND_ORIGIN}/cancel", + }, + ) + + assert response.status_code == 200 + assert response.json()["url"] == "" + modify_mock.assert_awaited_once_with(TEST_USER_ID, SubscriptionTier.PRO) + checkout_mock.assert_not_awaited() + + +def test_stripe_webhook_dispatches_subscription_schedule_released( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """subscription_schedule.released routes to sync_subscription_schedule_from_stripe.""" + schedule_obj = {"id": "sub_sched_1", "subscription": "sub_pro"} + event = { + "type": "subscription_schedule.released", + "data": {"object": schedule_obj}, + } + mocker.patch( + "backend.api.features.v1.settings.secrets.stripe_webhook_secret", + new="whsec_test", + ) + mocker.patch( + "backend.api.features.v1.stripe.Webhook.construct_event", + return_value=event, + ) + sync_mock = mocker.patch( + "backend.api.features.v1.sync_subscription_schedule_from_stripe", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/stripe_webhook", + content=b"{}", + headers={"stripe-signature": "t=1,v1=abc"}, + ) + + assert response.status_code == 200 + sync_mock.assert_awaited_once_with(schedule_obj) + + +def test_stripe_webhook_ignores_subscription_schedule_updated( + client: fastapi.testclient.TestClient, + mocker: pytest_mock.MockFixture, +) -> None: + """subscription_schedule.updated must NOT dispatch: our own + SubscriptionSchedule.create/.modify calls fire this event and would + otherwise loop redundant traffic through the sync handler. State + transitions we care about surface via .released/.completed, and phase + advance to a new price is already covered by customer.subscription.updated. + """ + schedule_obj = {"id": "sub_sched_1", "subscription": "sub_pro"} + event = { + "type": "subscription_schedule.updated", + "data": {"object": schedule_obj}, + } + mocker.patch( + "backend.api.features.v1.settings.secrets.stripe_webhook_secret", + new="whsec_test", + ) + mocker.patch( + "backend.api.features.v1.stripe.Webhook.construct_event", + return_value=event, + ) + sync_mock = mocker.patch( + "backend.api.features.v1.sync_subscription_schedule_from_stripe", + new_callable=AsyncMock, + ) + + response = client.post( + "/credits/stripe_webhook", + content=b"{}", + headers={"stripe-signature": "t=1,v1=abc"}, + ) + + assert response.status_code == 200 + sync_mock.assert_not_awaited() diff --git a/autogpt_platform/backend/backend/api/features/v1.py b/autogpt_platform/backend/backend/api/features/v1.py index ab0b69071d..3559071043 100644 --- a/autogpt_platform/backend/backend/api/features/v1.py +++ b/autogpt_platform/backend/backend/api/features/v1.py @@ -26,7 +26,7 @@ from fastapi import ( ) from fastapi.concurrency import run_in_threadpool from prisma.enums import SubscriptionTier -from pydantic import BaseModel +from pydantic import BaseModel, Field from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND from typing_extensions import Optional, TypedDict @@ -49,20 +49,24 @@ from backend.data.auth import api_key as api_key_db from backend.data.block import BlockInput, CompletedBlockOutput from backend.data.credit import ( AutoTopUpConfig, + PendingChangeUnknown, RefundRequest, TransactionHistory, UserCredit, cancel_stripe_subscription, create_subscription_checkout, get_auto_top_up, + get_pending_subscription_change, get_proration_credit_cents, get_subscription_price_id, get_user_credit_model, handle_subscription_payment_failure, modify_stripe_subscription_for_tier, + release_pending_subscription_schedule, set_auto_top_up, set_subscription_tier, sync_subscription_from_stripe, + sync_subscription_schedule_from_stripe, ) from backend.data.graph import GraphSettings from backend.data.model import CredentialsMetaInput, UserOnboarding @@ -698,15 +702,21 @@ class SubscriptionTierRequest(BaseModel): cancel_url: str = "" -class SubscriptionCheckoutResponse(BaseModel): - url: str - - class SubscriptionStatusResponse(BaseModel): tier: Literal["FREE", "PRO", "BUSINESS", "ENTERPRISE"] monthly_cost: int # amount in cents (Stripe convention) tier_costs: dict[str, int] # tier name -> amount in cents proration_credit_cents: int # unused portion of current sub to convert on upgrade + pending_tier: Optional[Literal["FREE", "PRO", "BUSINESS"]] = None + pending_tier_effective_at: Optional[datetime] = None + url: str = Field( + default="", + description=( + "Populated only when POST /credits/subscription starts a Stripe Checkout" + " Session (FREE → paid upgrade). Empty string in all other branches —" + " the client redirects to this URL when non-empty." + ), + ) def _validate_checkout_redirect_url(url: str) -> bool: @@ -804,17 +814,42 @@ async def get_subscription_status( current_monthly_cost = tier_costs.get(tier.value, 0) proration_credit = await get_proration_credit_cents(user_id, current_monthly_cost) - return SubscriptionStatusResponse( + try: + pending = await get_pending_subscription_change(user_id) + except (stripe.StripeError, PendingChangeUnknown): + # Swallow Stripe-side failures (rate limits, transient network) AND + # PendingChangeUnknown (LaunchDarkly price-id lookup failed). Both + # propagate past the cache so the next request retries fresh instead + # of serving a stale None for the TTL window. Let real bugs (KeyError, + # AttributeError, etc.) propagate so they surface in Sentry. + logger.exception( + "get_subscription_status: failed to resolve pending change for user %s", + user_id, + ) + pending = None + + response = SubscriptionStatusResponse( tier=tier.value, monthly_cost=current_monthly_cost, tier_costs=tier_costs, proration_credit_cents=proration_credit, ) + if pending is not None: + pending_tier_enum, pending_effective_at = pending + if pending_tier_enum == SubscriptionTier.FREE: + response.pending_tier = "FREE" + elif pending_tier_enum == SubscriptionTier.PRO: + response.pending_tier = "PRO" + elif pending_tier_enum == SubscriptionTier.BUSINESS: + response.pending_tier = "BUSINESS" + if response.pending_tier is not None: + response.pending_tier_effective_at = pending_effective_at + return response @v1_router.post( path="/credits/subscription", - summary="Start a Stripe Checkout session to upgrade subscription tier", + summary="Update subscription tier or start a Stripe Checkout session", operation_id="updateSubscriptionTier", tags=["credits"], dependencies=[Security(requires_user)], @@ -822,7 +857,7 @@ async def get_subscription_status( async def update_subscription_tier( request: SubscriptionTierRequest, user_id: Annotated[str, Security(get_user_id)], -) -> SubscriptionCheckoutResponse: +) -> SubscriptionStatusResponse: # Pydantic validates tier is one of FREE/PRO/BUSINESS via Literal type. tier = SubscriptionTier(request.tier) @@ -834,6 +869,29 @@ async def update_subscription_tier( detail="ENTERPRISE subscription changes must be managed by an administrator", ) + # Same-tier request = "stay on my current tier" = cancel any pending + # scheduled change (paid→paid downgrade or paid→FREE cancel). This is the + # collapsed behaviour that replaces the old /credits/subscription/cancel-pending + # route. Safe when no pending change exists: release_pending_subscription_schedule + # returns False and we simply return the current status. + if (user.subscription_tier or SubscriptionTier.FREE) == tier: + try: + await release_pending_subscription_schedule(user_id) + except stripe.StripeError as e: + logger.exception( + "Stripe error releasing pending subscription change for user %s: %s", + user_id, + e, + ) + raise HTTPException( + status_code=502, + detail=( + "Unable to cancel the pending subscription change right now. " + "Please try again or contact support." + ), + ) + return await get_subscription_status(user_id) + payment_enabled = await is_feature_enabled( Flag.ENABLE_PLATFORM_PAYMENT, user_id, default=False ) @@ -871,9 +929,9 @@ async def update_subscription_tier( # admin-granted tier. Update DB immediately since the # subscription.deleted webhook will never fire. await set_subscription_tier(user_id, tier) - return SubscriptionCheckoutResponse(url="") + return await get_subscription_status(user_id) await set_subscription_tier(user_id, tier) - return SubscriptionCheckoutResponse(url="") + return await get_subscription_status(user_id) # Paid tier changes require payment to be enabled — block self-service upgrades # when the flag is off. Admins use the /api/admin/ routes to set tiers directly. @@ -883,15 +941,6 @@ async def update_subscription_tier( detail=f"Subscription not available for tier {tier}", ) - # No-op short-circuit: if the user is already on the requested paid tier, - # do NOT create a new Checkout Session. Without this guard, a duplicate - # request (double-click, retried POST, stale page) creates a second - # subscription for the same price; the user would be charged for both - # until `_cleanup_stale_subscriptions` runs from the resulting webhook — - # which only fires after the second charge has cleared. - if (user.subscription_tier or SubscriptionTier.FREE) == tier: - return SubscriptionCheckoutResponse(url="") - # Paid→paid tier change: if the user already has a Stripe subscription, # modify it in-place with proration instead of creating a new Checkout # Session. This preserves remaining paid time and avoids double-charging. @@ -901,14 +950,14 @@ async def update_subscription_tier( try: modified = await modify_stripe_subscription_for_tier(user_id, tier) if modified: - return SubscriptionCheckoutResponse(url="") + return await get_subscription_status(user_id) # modify_stripe_subscription_for_tier returns False when no active # Stripe subscription exists — i.e. the user has an admin-granted # paid tier with no Stripe record. In that case, update the DB # tier directly (same as the FREE-downgrade path for admin-granted # users) rather than sending them through a new Checkout Session. await set_subscription_tier(user_id, tier) - return SubscriptionCheckoutResponse(url="") + return await get_subscription_status(user_id) except ValueError as e: raise HTTPException(status_code=422, detail=str(e)) except stripe.StripeError as e: @@ -978,7 +1027,9 @@ async def update_subscription_tier( ), ) - return SubscriptionCheckoutResponse(url=url) + status = await get_subscription_status(user_id) + status.url = url + return status @v1_router.post( @@ -1043,6 +1094,18 @@ async def stripe_webhook(request: Request): ): await sync_subscription_from_stripe(data_object) + # `subscription_schedule.updated` is deliberately omitted: our own + # `SubscriptionSchedule.create` + `.modify` calls in + # `_schedule_downgrade_at_period_end` would fire that event right back at us + # and loop redundant traffic through this handler. We only care about state + # transitions (released / completed); phase advance to the new price is + # already covered by `customer.subscription.updated`. + if event_type in ( + "subscription_schedule.released", + "subscription_schedule.completed", + ): + await sync_subscription_schedule_from_stripe(data_object) + if event_type == "invoice.payment_failed": await handle_subscription_payment_failure(data_object) diff --git a/autogpt_platform/backend/backend/copilot/rate_limit.py b/autogpt_platform/backend/backend/copilot/rate_limit.py index 3124c28992..c08cb1b3a8 100644 --- a/autogpt_platform/backend/backend/copilot/rate_limit.py +++ b/autogpt_platform/backend/backend/copilot/rate_limit.py @@ -17,6 +17,7 @@ from redis.exceptions import RedisError from backend.data.db_accessors import user_db from backend.data.redis_client import get_redis_async +from backend.data.user import get_user_by_id from backend.util.cache import cached logger = logging.getLogger(__name__) @@ -459,8 +460,20 @@ get_user_tier.cache_delete = _fetch_user_tier.cache_delete # type: ignore[attr- async def set_user_tier(user_id: str, tier: SubscriptionTier) -> None: """Persist the user's rate-limit tier to the database. - Also invalidates the ``get_user_tier`` cache for this user so that - subsequent rate-limit checks immediately see the new tier. + Invalidates every cache that keys off the user's subscription tier so the + change is visible immediately: this function's own ``get_user_tier``, the + shared ``get_user_by_id`` (which exposes ``user.subscription_tier``), and + ``get_pending_subscription_change`` (since an admin override can invalidate + a cached ``cancel_at_period_end`` or schedule-based pending state). + + If the user has an active Stripe subscription whose current price does not + match ``tier``, Stripe will keep billing the old price and the next + ``customer.subscription.updated`` webhook will overwrite the DB tier back + to whatever Stripe has. Proper reconciliation (cancelling or modifying the + Stripe subscription when an admin overrides the tier) is out of scope for + this PR — it changes the admin contract and needs its own test coverage. + For now we emit a ``WARNING`` so drift surfaces via Sentry until that + follow-up lands. Raises: prisma.errors.RecordNotFoundError: If the user does not exist. @@ -469,8 +482,113 @@ async def set_user_tier(user_id: str, tier: SubscriptionTier) -> None: where={"id": user_id}, data={"subscriptionTier": tier.value}, ) - # Invalidate cached tier so rate-limit checks pick up the change immediately. get_user_tier.cache_delete(user_id) # type: ignore[attr-defined] + # Local import required: backend.data.credit imports backend.copilot.rate_limit + # (via get_user_tier in credit.py's _invalidate_user_tier_caches), so a + # top-level ``from backend.data.credit import ...`` here would create a + # circular import at module-load time. + from backend.data.credit import get_pending_subscription_change + + get_user_by_id.cache_delete(user_id) # type: ignore[attr-defined] + get_pending_subscription_change.cache_delete(user_id) # type: ignore[attr-defined] + + # The DB write above is already committed; the drift check is best-effort + # diagnostic logging. Fire-and-forget so admin bulk ops don't wait on a + # Stripe roundtrip. The inner helper wraps its body in a timeout + broad + # except so background task errors still surface via logs rather than as + # "task exception never retrieved" warnings. Cancellation on request + # shutdown is acceptable — the drift warning is non-load-bearing. + asyncio.ensure_future(_drift_check_background(user_id, tier)) + + +async def _drift_check_background(user_id: str, tier: SubscriptionTier) -> None: + """Run the Stripe drift check in the background, logging rather than raising.""" + try: + await asyncio.wait_for( + _warn_if_stripe_subscription_drifts(user_id, tier), + timeout=5.0, + ) + logger.debug( + "set_user_tier: drift check completed for user=%s admin_tier=%s", + user_id, + tier.value, + ) + except asyncio.TimeoutError: + logger.warning( + "set_user_tier: drift check timed out for user=%s admin_tier=%s", + user_id, + tier.value, + ) + except asyncio.CancelledError: + # Request may have completed and the event loop is cancelling tasks — + # the drift log is non-critical, so accept cancellation silently. + raise + except Exception: + logger.exception( + "set_user_tier: drift check background task failed for" + " user=%s admin_tier=%s", + user_id, + tier.value, + ) + + +async def _warn_if_stripe_subscription_drifts( + user_id: str, new_tier: SubscriptionTier +) -> None: + """Emit a WARNING when an admin tier override leaves an active Stripe sub on a + mismatched price. + + The warning is diagnostic only: Stripe remains the billing source of truth, + so the next ``customer.subscription.updated`` webhook will reset the DB + tier. Surfacing the drift here lets ops catch admin overrides that bypass + the intended Checkout / Portal cancel flows before users notice surprise + charges. + """ + # Local imports: see note in ``set_user_tier`` about the credit <-> rate_limit + # circular. These helpers (``_get_active_subscription``, + # ``get_subscription_price_id``) live in credit.py alongside the rest of + # the Stripe billing code. + from backend.data.credit import _get_active_subscription, get_subscription_price_id + + try: + user = await get_user_by_id(user_id) + if not getattr(user, "stripe_customer_id", None): + return + sub = await _get_active_subscription(user.stripe_customer_id) + if sub is None: + return + items = sub["items"].data + if not items: + return + price = items[0].price + current_price_id = price if isinstance(price, str) else price.id + # The LaunchDarkly-backed price lookup must live inside this try/except: + # an LD SDK failure (network, token revoked) here would otherwise + # propagate past set_user_tier's already-committed DB write and turn a + # best-effort diagnostic into a 500 on admin tier writes. + expected_price_id = await get_subscription_price_id(new_tier) + except Exception: + logger.debug( + "_warn_if_stripe_subscription_drifts: drift lookup failed for" + " user=%s; skipping drift warning", + user_id, + exc_info=True, + ) + return + if expected_price_id is not None and expected_price_id == current_price_id: + return + logger.warning( + "Admin tier override will drift from Stripe: user=%s admin_tier=%s" + " stripe_sub=%s stripe_price=%s expected_price=%s — the next" + " customer.subscription.updated webhook will reconcile the DB tier" + " back to whatever Stripe has; cancel or modify the Stripe subscription" + " if you intended the admin override to stick.", + user_id, + new_tier.value, + sub.id, + current_price_id, + expected_price_id, + ) async def get_global_rate_limits( diff --git a/autogpt_platform/backend/backend/copilot/rate_limit_test.py b/autogpt_platform/backend/backend/copilot/rate_limit_test.py index ea87658710..577093c752 100644 --- a/autogpt_platform/backend/backend/copilot/rate_limit_test.py +++ b/autogpt_platform/backend/backend/copilot/rate_limit_test.py @@ -581,6 +581,80 @@ class TestSetUserTier: assert tier_after == SubscriptionTier.ENTERPRISE + @pytest.mark.asyncio + async def test_drift_check_swallows_launchdarkly_failure(self): + """LaunchDarkly price-id lookup failures inside the drift check must + never bubble up and 500 the admin tier write — the DB update is + already committed by the time we check drift.""" + mock_prisma = AsyncMock() + mock_prisma.update = AsyncMock(return_value=None) + + mock_user = MagicMock() + mock_user.stripe_customer_id = "cus_abc" + + mock_sub = MagicMock() + mock_sub.id = "sub_abc" + mock_sub["items"].data = [MagicMock(price=MagicMock(id="price_mismatch"))] + + with ( + patch( + "backend.copilot.rate_limit.PrismaUser.prisma", + return_value=mock_prisma, + ), + patch( + "backend.copilot.rate_limit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit._get_active_subscription", + new_callable=AsyncMock, + return_value=mock_sub, + ), + patch( + "backend.data.credit.get_subscription_price_id", + new_callable=AsyncMock, + side_effect=RuntimeError("LD SDK not initialized"), + ), + ): + # Must NOT raise — drift check is best-effort diagnostic only. + await set_user_tier(_USER, SubscriptionTier.PRO) + + mock_prisma.update.assert_awaited_once() + + @pytest.mark.asyncio + async def test_drift_check_timeout_is_bounded(self): + """A Stripe call that stalls on the 80s SDK default must not block the + admin tier write — set_user_tier wraps the drift check in a 5s timeout + and logs + returns on TimeoutError.""" + import asyncio as _asyncio + + mock_prisma = AsyncMock() + mock_prisma.update = AsyncMock(return_value=None) + + async def _never_returns(_user_id: str, _tier): + await _asyncio.sleep(60) + + with ( + patch( + "backend.copilot.rate_limit.PrismaUser.prisma", + return_value=mock_prisma, + ), + patch( + "backend.copilot.rate_limit._warn_if_stripe_subscription_drifts", + side_effect=_never_returns, + ), + patch( + "backend.copilot.rate_limit.asyncio.wait_for", + new_callable=AsyncMock, + side_effect=_asyncio.TimeoutError, + ), + ): + await set_user_tier(_USER, SubscriptionTier.PRO) + + # Set_user_tier still completed — the drift timeout did not propagate. + mock_prisma.update.assert_awaited_once() + # --------------------------------------------------------------------------- # get_global_rate_limits with tiers diff --git a/autogpt_platform/backend/backend/data/credit.py b/autogpt_platform/backend/backend/data/credit.py index e97578d5cc..a42ba91be8 100644 --- a/autogpt_platform/backend/backend/data/credit.py +++ b/autogpt_platform/backend/backend/data/credit.py @@ -15,7 +15,7 @@ from prisma.enums import ( OnboardingStep, SubscriptionTier, ) -from prisma.errors import UniqueViolationError +from prisma.errors import PrismaError, UniqueViolationError from prisma.models import CreditRefundRequest, CreditTransaction, User, UserBalance from prisma.types import CreditRefundRequestCreateInput, CreditTransactionWhereInput from pydantic import BaseModel @@ -1280,6 +1280,12 @@ async def set_subscription_tier(user_id: str, tier: SubscriptionTier) -> None: from backend.copilot.rate_limit import get_user_tier # local import avoids circular get_user_tier.cache_delete(user_id) # type: ignore[attr-defined] + # Invalidate the pending-change cache too — an admin tier override or the + # webhook-driven phase transition means any cached pending-change state + # (schedule, cancel_at_period_end) is likely stale. Without this the + # billing page can show a pending change for up to 30s after the tier + # has already flipped. + get_pending_subscription_change.cache_delete(user_id) async def _cancel_customer_subscriptions( @@ -1330,6 +1336,21 @@ async def _cancel_customer_subscriptions( continue seen_ids.add(sub_id) if at_period_end: + # Stripe rejects modify(cancel_at_period_end=True) with 400 when a + # Subscription Schedule is attached (e.g. the user previously + # queued a paid→paid downgrade and is now clicking "Cancel"). + # Release the schedule first so the cancel flag can be set; the + # schedule's pending phase change is superseded by the cancel. + existing_schedule = sub.schedule + if existing_schedule: + schedule_id = ( + existing_schedule + if isinstance(existing_schedule, str) + else existing_schedule.id + ) + await _release_schedule_ignoring_terminal( + schedule_id, "_cancel_customer_subscriptions" + ) await run_in_threadpool( stripe.Subscription.modify, sub_id, cancel_at_period_end=True ) @@ -1366,6 +1387,8 @@ async def cancel_stripe_subscription(user_id: str) -> bool: cancelled_count = await _cancel_customer_subscriptions( customer_id, at_period_end=True ) + if cancelled_count > 0: + get_pending_subscription_change.cache_delete(user_id) return cancelled_count > 0 except stripe.StripeError: logger.warning( @@ -1415,18 +1438,224 @@ async def get_proration_credit_cents(user_id: str, monthly_cost_cents: int) -> i return 0 +# Ordered from least- to most-privileged. Used to distinguish upgrades +# (move right) from downgrades (move left); ENTERPRISE is admin-managed and +# never reached via self-service flows. +_TIER_ORDER: tuple[SubscriptionTier, ...] = ( + SubscriptionTier.FREE, + SubscriptionTier.PRO, + SubscriptionTier.BUSINESS, + SubscriptionTier.ENTERPRISE, +) + + +def _tier_rank(tier: SubscriptionTier) -> int: + return _TIER_ORDER.index(tier) + + +def is_tier_upgrade(current: SubscriptionTier, target: SubscriptionTier) -> bool: + return _tier_rank(target) > _tier_rank(current) + + +def is_tier_downgrade(current: SubscriptionTier, target: SubscriptionTier) -> bool: + return _tier_rank(target) < _tier_rank(current) + + +class PendingChangeUnknown(Exception): + """Raised when pending-change state cannot be determined (e.g. LaunchDarkly + price-id lookup failed). Propagates past the @cached wrapper so the next + request retries instead of serving a stale `None` for the TTL window.""" + + +async def _get_active_subscription(customer_id: str) -> stripe.Subscription | None: + """Return the customer's active or trialing subscription, or None.""" + for status in ("active", "trialing"): + subs = await stripe.Subscription.list_async( + customer=customer_id, status=status, limit=1 + ) + if subs.data: + return subs.data[0] + return None + + +# Substrings Stripe uses in InvalidRequestError messages when the schedule is +# already in a terminal state (released / completed / canceled) and therefore +# cannot be released again. We only swallow the error when one of these appears; +# anything else (typo'd schedule id, wrong subscription, 404, etc.) must +# propagate so bugs aren't masked as silent no-ops. +_TERMINAL_SCHEDULE_ERROR_SUBSTRINGS = ( + "already been released", + "already released", + "already been completed", + "already completed", + "already been canceled", + "already been cancelled", + "already canceled", + "already cancelled", + "is not active", + "is not in a state", +) + + +async def _release_schedule_ignoring_terminal( + schedule_id: str, log_context: str +) -> bool: + """Release a Stripe schedule; swallow InvalidRequestError on terminal state. + + Returns True if the release call succeeded, False if the schedule was + already in a terminal (released / completed / canceled) state. Any other + Stripe error — including non-terminal InvalidRequestErrors such as typo'd + ids or 404s — propagates so the caller can surface the failure instead of + silently masking a bug. + """ + try: + await stripe.SubscriptionSchedule.release_async(schedule_id) + return True + except stripe.InvalidRequestError as e: + message = getattr(e, "user_message", None) or str(e) + if not any( + marker in message.lower() for marker in _TERMINAL_SCHEDULE_ERROR_SUBSTRINGS + ): + logger.warning( + "%s: schedule %s release failed with non-terminal" + " InvalidRequestError (%s); re-raising", + log_context, + schedule_id, + message, + ) + raise + logger.warning( + "%s: schedule %s not releasable (%s); treating as already released", + log_context, + schedule_id, + message, + ) + return False + + +async def _schedule_downgrade_at_period_end( + sub: stripe.Subscription, + new_price_id: str, + user_id: str, + tier: SubscriptionTier, +) -> None: + """Create a Subscription Schedule that defers a tier change to period end. + + Stripe's Subscription Schedule drives an existing subscription through a + series of phases. By keeping the current price for the remainder of the + billing period and switching to ``new_price_id`` afterwards, the user does + NOT receive an immediate proration charge and keeps their current tier + until period end. + + Stripe allows at most one active schedule per subscription and rejects + ``SubscriptionSchedule.create`` if either (a) a schedule is already + attached to the subscription or (b) ``cancel_at_period_end=True`` is set. + Both conditions mean the user is overwriting a pending change they made + earlier (e.g. BUSINESS→FREE cancel, now switching to BUSINESS→PRO + downgrade). We clear the conflicting state first so the new schedule can + be created. These defensive reads serialize through Stripe's own atomic + operations — by the time modify/release returns, the subscription is in a + known-clean state for the subsequent create. + """ + sub_id = sub.id + # ``sub["items"]`` (dict-item) rather than ``sub.items`` because the latter + # is shadowed by Python's dict.items() method on StripeObject. + items = sub["items"].data + if not items: + raise ValueError(f"Subscription {sub_id} has no items; cannot schedule") + price = items[0].price + current_price_id = price if isinstance(price, str) else price.id + period_start: int = sub["current_period_start"] + period_end: int = sub["current_period_end"] + + if sub.cancel_at_period_end: + await stripe.Subscription.modify_async(sub_id, cancel_at_period_end=False) + logger.info( + "_schedule_downgrade_at_period_end: cleared cancel_at_period_end" + " on sub %s for user %s before scheduling downgrade", + sub_id, + user_id, + ) + if sub.schedule: + existing_schedule_id = ( + sub.schedule if isinstance(sub.schedule, str) else sub.schedule.id + ) + await _release_schedule_ignoring_terminal( + existing_schedule_id, "_schedule_downgrade_at_period_end" + ) + + # Create + modify as a two-step transaction. If modify fails (network, + # Stripe 500) the created schedule is orphaned AND attached to the + # subscription, which blocks any future Stripe-side change until manually + # released. Roll back by releasing the orphan, then re-raise so the caller + # sees the original failure. + schedule = await stripe.SubscriptionSchedule.create_async(from_subscription=sub_id) + try: + await stripe.SubscriptionSchedule.modify_async( + schedule.id, + phases=[ + { + "items": [{"price": current_price_id, "quantity": 1}], + "start_date": period_start, + "end_date": period_end, + "proration_behavior": "none", + }, + { + "items": [{"price": new_price_id, "quantity": 1}], + "proration_behavior": "none", + }, + ], + metadata={"user_id": user_id, "pending_tier": tier.value}, + ) + except stripe.StripeError: + logger.exception( + "_schedule_downgrade_at_period_end: modify failed for schedule %s" + " on sub %s user %s; attempting rollback release", + schedule.id, + sub_id, + user_id, + ) + try: + await _release_schedule_ignoring_terminal( + schedule.id, "_schedule_downgrade_at_period_end_rollback" + ) + except stripe.StripeError: + logger.exception( + "_schedule_downgrade_at_period_end: rollback release also failed" + " for orphaned schedule %s on sub %s user %s; manual cleanup" + " required", + schedule.id, + sub_id, + user_id, + ) + raise + logger.info( + "modify_stripe_subscription_for_tier: scheduled sub %s downgrade for user %s → %s at %d", + sub_id, + user_id, + tier, + period_end, + ) + + async def modify_stripe_subscription_for_tier( user_id: str, tier: SubscriptionTier ) -> bool: - """Modify an existing Stripe subscription to a new paid tier using proration. + """Change a Stripe subscription to a new paid tier. - For paid→paid tier changes (e.g. PRO↔BUSINESS), modifying the existing - subscription is preferable to cancelling + creating a new one via Checkout: - Stripe handles proration automatically, crediting unused time on the old plan - and charging the pro-rated amount for the new plan in the same billing cycle. + Upgrades (e.g. PRO→BUSINESS) apply immediately via ``stripe.Subscription.modify`` + with ``proration_behavior="create_prorations"``: Stripe credits unused time on + the old plan and charges the pro-rated amount for the new plan in the same + billing cycle. + + Downgrades (e.g. BUSINESS→PRO) are deferred to the end of the current billing + period via a Stripe Subscription Schedule: the user keeps their current tier + for the time they already paid for, and the new tier takes effect when the + next invoice is generated. The DB tier flip happens via the webhook fired + when the schedule advances to its next phase. Returns: - True — a subscription was found and modified successfully. + True — a subscription was found and modified/scheduled successfully. False — no active/trialing subscription exists (e.g. admin-granted tier or first-time paid signup); caller should fall back to Checkout. @@ -1437,41 +1666,262 @@ async def modify_stripe_subscription_for_tier( if not price_id: raise ValueError(f"No Stripe price ID configured for tier {tier}") - # Guard: only proceed if the user already has a Stripe customer ID. Calling - # get_stripe_customer_id for a user with no Stripe record (e.g. admin-granted tier) - # would create an orphaned customer object if the subsequent Subscription.list call - # fails. Return False early so the API layer falls back to Checkout instead. user = await get_user_by_id(user_id) if not user.stripe_customer_id: return False + current_tier = user.subscription_tier or SubscriptionTier.FREE - customer_id = user.stripe_customer_id - for status in ("active", "trialing"): - subscriptions = await run_in_threadpool( - stripe.Subscription.list, customer=customer_id, status=status, limit=1 - ) - if not subscriptions.data: - continue - sub = subscriptions.data[0] - sub_id = sub["id"] - items = sub.get("items", {}).get("data", []) - if not items: - continue - item_id = items[0]["id"] - await run_in_threadpool( - stripe.Subscription.modify, - sub_id, - items=[{"id": item_id, "price": price_id}], - proration_behavior="create_prorations", - ) + sub = await _get_active_subscription(user.stripe_customer_id) + if sub is None: + return False + items = sub["items"].data + if not items: + return False + sub_id = sub.id + + # Invalidate the cache unconditionally on exit (success OR failure): any + # Stripe mutation below — clearing cancel_at_period_end, releasing an old + # schedule, creating a new one — may have landed partially before an error + # was raised, and the cached pending-change state would otherwise go stale + # for up to 30s until the TTL expires. + try: + if is_tier_downgrade(current_tier, tier): + await _schedule_downgrade_at_period_end(sub, price_id, user_id, tier) + return True + + # Upgrade path. If a schedule is attached from a previous pending + # downgrade, release it first — an upgrade expresses the user's + # intent to be on this tier immediately, which overrides any pending + # deferred change. Ignore terminal-state errors from release. + if sub.schedule: + existing_schedule_id = ( + sub.schedule if isinstance(sub.schedule, str) else sub.schedule.id + ) + await _release_schedule_ignoring_terminal( + existing_schedule_id, "modify_stripe_subscription_for_tier" + ) + + # If a paid→FREE cancel is pending (cancel_at_period_end=True), clear it + # as part of the upgrade — the user is explicitly choosing to stay on a + # paid tier. Without this, the sub would be upgraded AND still cancelled + # at period end, leaving a confusing dual state. + modify_kwargs: dict = { + "items": [{"id": items[0].id, "price": price_id}], + "proration_behavior": "create_prorations", + } + if sub.cancel_at_period_end: + modify_kwargs["cancel_at_period_end"] = False + + await stripe.Subscription.modify_async(sub_id, **modify_kwargs) + # Flip the DB tier immediately. The customer.subscription.updated webhook + # will also fire and set it again — idempotent. Without this synchronous + # update, the UI refetches before the webhook lands and shows the old + # tier, making the upgrade look like a no-op to the user. + # + # Swallow DB-write exceptions here: Stripe is authoritative and the + # modify above already succeeded (the user has been charged). If the + # DB write fails and we re-raised, the API would return 5xx and the UI + # would surface a failed upgrade to a user who was already charged. + # The customer.subscription.updated webhook will reconcile the DB shortly. + # + # Only catch actual DB/connection failures — letting KeyError, + # AttributeError etc. propagate so programming errors surface in Sentry + # instead of being silently masked as benign DB-write-swallow events. + try: + await set_subscription_tier(user_id, tier) + except (PrismaError, ConnectionError, asyncio.TimeoutError): + logger.exception( + "modify_stripe_subscription_for_tier: Stripe modify on sub %s" + " succeeded for user %s → %s but DB tier flip failed; webhook" + " will reconcile", + sub_id, + user_id, + tier, + ) logger.info( - "modify_stripe_subscription_for_tier: modified sub %s for user %s → %s", + "modify_stripe_subscription_for_tier: upgraded sub %s for user %s → %s", sub_id, user_id, tier, ) return True - return False + finally: + get_pending_subscription_change.cache_delete(user_id) + + +async def release_pending_subscription_schedule(user_id: str) -> bool: + """Cancel any pending subscription change (scheduled downgrade or cancellation). + + Two pending-change mechanisms can be attached to a Stripe subscription: + + - **Subscription Schedule** (paid→paid downgrade): ``stripe.SubscriptionSchedule.release`` + detaches the schedule and lets the subscription continue on its current + phase's price. + - **cancel_at_period_end=True** (paid→FREE cancel): clearing that flag via + ``stripe.Subscription.modify`` keeps the subscription active indefinitely. + + Returns True if a pending change was found and reverted, False otherwise. + """ + user = await get_user_by_id(user_id) + if not user.stripe_customer_id: + return False + + sub = await _get_active_subscription(user.stripe_customer_id) + if sub is None: + return False + + sub_id = sub.id + did_anything = False + schedule_released = False + schedule_id: str | None = None + try: + if sub.schedule: + schedule_id = ( + sub.schedule if isinstance(sub.schedule, str) else sub.schedule.id + ) + schedule_released = await _release_schedule_ignoring_terminal( + schedule_id, "release_pending_subscription_schedule" + ) + if schedule_released: + logger.info( + "release_pending_subscription_schedule: released schedule %s for user %s", + schedule_id, + user_id, + ) + did_anything = True + if sub.cancel_at_period_end: + try: + await stripe.Subscription.modify_async( + sub_id, cancel_at_period_end=False + ) + except stripe.StripeError: + if schedule_released: + logger.exception( + "release_pending_subscription_schedule: partial release" + " — schedule %s released but cancel_at_period_end clear" + " failed on sub %s for user %s; manual reconciliation" + " may be needed", + schedule_id, + sub_id, + user_id, + ) + raise + did_anything = True + logger.info( + "release_pending_subscription_schedule: cleared cancel_at_period_end" + " on sub %s for user %s", + sub_id, + user_id, + ) + finally: + if did_anything: + get_pending_subscription_change.cache_delete(user_id) + return did_anything + + +@cached(ttl_seconds=30, maxsize=512, cache_none=True, shared_cache=True) +async def get_pending_subscription_change( + user_id: str, +) -> tuple[SubscriptionTier, datetime] | None: + """Return ``(pending_tier, effective_at)`` when a change is queued, else ``None``. + + Reflects both Subscription Schedule phase transitions (paid→paid downgrade) + and ``cancel_at_period_end=True`` (paid→FREE cancel). + + Cached for 30 seconds per user_id. *Why the cache exists:* this function + runs on every dashboard/home fetch and would otherwise fire + 2× Subscription.list + 1× Schedule.retrieve per page load. A busy user + polling the billing page would quickly brush up against Stripe's per-API + rate limits; the 30s TTL absorbs dashboard polling while being short + enough that the UI reconciles quickly after a downgrade / cancel action. + + *Invalidation contract.* Every call-site that mutates Stripe state which + could change the pending-change answer MUST call + ``get_pending_subscription_change.cache_delete(user_id)`` so the UI never + shows a stale pending badge after a user-visible action. Current + invalidators (keep this list in sync when adding new mutators): + + - ``set_subscription_tier`` — admin or webhook-driven tier flip. + - ``modify_stripe_subscription_for_tier`` — ``finally`` block (covers + upgrade path clear + downgrade-schedule create + any partial failure). + - ``release_pending_subscription_schedule`` — ``finally`` block when a + schedule release OR ``cancel_at_period_end`` clear succeeded. + - ``cancel_stripe_subscription`` — after scheduling period-end cancel. + - ``sync_subscription_from_stripe`` — webhook entry point. + - ``set_user_tier`` (``backend.copilot.rate_limit``) — admin tier override + invalidates any cached pending state keyed off the old tier. + """ + user = await get_user_by_id(user_id) + if not user.stripe_customer_id: + # Short-circuit for users with no Stripe customer (admin-granted tiers, + # FREE-only users): skip the Stripe API calls entirely. + return None + + pro_price, biz_price = await asyncio.gather( + get_subscription_price_id(SubscriptionTier.PRO), + get_subscription_price_id(SubscriptionTier.BUSINESS), + ) + price_to_tier: dict[str, SubscriptionTier] = {} + if pro_price: + price_to_tier[pro_price] = SubscriptionTier.PRO + if biz_price: + price_to_tier[biz_price] = SubscriptionTier.BUSINESS + if not price_to_tier: + logger.warning( + "get_pending_subscription_change: no Stripe price IDs resolvable for" + " PRO/BUSINESS (LaunchDarkly fetch failed?); raising to bypass the" + " None cache so the next request retries fresh" + ) + raise PendingChangeUnknown( + "Stripe price lookup failed; pending-change state cannot be determined" + ) + + sub = await _get_active_subscription(user.stripe_customer_id) + if sub is None: + return None + period_end = sub.current_period_end + if not isinstance(period_end, int): + return None + effective_at = datetime.fromtimestamp(period_end, tz=timezone.utc) + if sub.cancel_at_period_end: + return SubscriptionTier.FREE, effective_at + if not sub.schedule: + return None + schedule_id = sub.schedule if isinstance(sub.schedule, str) else sub.schedule.id + schedule = await stripe.SubscriptionSchedule.retrieve_async(schedule_id) + return _next_phase_tier_and_start(schedule, price_to_tier) + + +def _next_phase_tier_and_start( + schedule: stripe.SubscriptionSchedule, + price_to_tier: dict[str, SubscriptionTier], +) -> tuple[SubscriptionTier, datetime] | None: + """Return (tier, start_datetime) of the phase that follows the active one. + + Using the phase's own ``start_date`` (not the subscription's current_period_end) + is correct even for schedules created outside this flow — a dashboard-authored + schedule can have phase transitions at arbitrary timestamps. + """ + now = int(time.time()) + for phase in schedule.phases or []: + if not isinstance(phase.start_date, int) or phase.start_date <= now: + continue + # ``phase["items"]`` because ``phase.items`` is shadowed by dict.items(). + items = phase["items"] or [] + if not items: + continue + price = items[0].price + price_id = price if isinstance(price, str) else price.id + if price_id in price_to_tier: + return price_to_tier[price_id], datetime.fromtimestamp( + phase.start_date, tz=timezone.utc + ) + logger.warning( + "next_phase_tier_and_start: unknown price %s on schedule %s", + price_id, + schedule.id, + ) + return None async def get_auto_top_up(user_id: str) -> AutoTopUpConfig: @@ -1732,6 +2182,50 @@ async def sync_subscription_from_stripe(stripe_subscription: dict) -> None: # cancel the old sub. await _cleanup_stale_subscriptions(customer_id, new_sub_id) await set_subscription_tier(user.id, tier) + # Tier changed — bust any cached pending-change view so the next + # dashboard fetch reflects the new state immediately. + get_pending_subscription_change.cache_delete(user.id) + + +async def sync_subscription_schedule_from_stripe(stripe_schedule: dict) -> None: + """Sync the DB tier from a ``subscription_schedule.*`` webhook event. + + Stripe fires ``subscription_schedule.released`` / ``.completed`` / + ``.updated`` when a schedule advances phases or is detached. The regular + ``customer.subscription.updated`` webhook with the new price covers the + phase transition in most cases, but listening to schedule events is a + safety net that also catches releases done via the Stripe dashboard. + + The schedule payload doesn't carry the active price directly — it carries + a ``subscription`` id that we look up to get the current item. + + Webhook-ordering safety: we deliberately funnel both event sources through + ``sync_subscription_from_stripe`` so they share one code path and one DB + write. That function is idempotent — it no-ops when ``current_tier == + tier`` — so concurrent or out-of-order deliveries of + ``subscription_schedule.*`` and ``customer.subscription.updated`` converge + to the same DB state regardless of which arrives first. + """ + # When a schedule is released, Stripe clears `subscription` and moves the id + # to `released_subscription`. Fall back to that so `.released` events — the + # main reason we listen to schedule webhooks as a safety net — are processed. + sub_id = stripe_schedule.get("subscription") or stripe_schedule.get( + "released_subscription" + ) + if not isinstance(sub_id, str) or not sub_id: + logger.warning( + "sync_subscription_schedule_from_stripe: no 'subscription' id; skipping" + ) + return + try: + sub = await stripe.Subscription.retrieve_async(sub_id) + except stripe.StripeError: + logger.warning( + "sync_subscription_schedule_from_stripe: failed to retrieve sub %s", + sub_id, + ) + return + await sync_subscription_from_stripe(dict(sub)) async def handle_subscription_payment_failure(invoice: dict) -> None: diff --git a/autogpt_platform/backend/backend/data/credit_subscription_test.py b/autogpt_platform/backend/backend/data/credit_subscription_test.py index a9634afcb4..d38f71d09e 100644 --- a/autogpt_platform/backend/backend/data/credit_subscription_test.py +++ b/autogpt_platform/backend/backend/data/credit_subscription_test.py @@ -12,11 +12,16 @@ from prisma.models import User from backend.data.credit import ( cancel_stripe_subscription, create_subscription_checkout, + get_pending_subscription_change, get_proration_credit_cents, handle_subscription_payment_failure, + is_tier_downgrade, + is_tier_upgrade, modify_stripe_subscription_for_tier, + release_pending_subscription_schedule, set_subscription_tier, sync_subscription_from_stripe, + sync_subscription_schedule_from_stripe, ) @@ -310,7 +315,11 @@ def _make_user_with_stripe(stripe_customer_id: str | None = "cus_123") -> MagicM @pytest.mark.asyncio async def test_cancel_stripe_subscription_cancels_active(): mock_subscriptions = MagicMock() - mock_subscriptions.data = [{"id": "sub_abc123"}] + mock_subscriptions.data = [ + stripe.Subscription.construct_from( + {"id": "sub_abc123", "schedule": None}, "sk_test" + ) + ] mock_subscriptions.has_more = False with ( @@ -346,7 +355,14 @@ async def test_cancel_stripe_subscription_no_customer_id_returns_false(): async def test_cancel_stripe_subscription_multi_partial_failure(): """First modify raises → error propagates and subsequent subs are not scheduled.""" mock_subscriptions = MagicMock() - mock_subscriptions.data = [{"id": "sub_first"}, {"id": "sub_second"}] + mock_subscriptions.data = [ + stripe.Subscription.construct_from( + {"id": "sub_first", "schedule": None}, "sk_test" + ), + stripe.Subscription.construct_from( + {"id": "sub_second", "schedule": None}, "sk_test" + ), + ] mock_subscriptions.has_more = False with ( @@ -428,7 +444,11 @@ async def test_cancel_stripe_subscription_cancels_trialing(): active_subs.data = [] active_subs.has_more = False trialing_subs = MagicMock() - trialing_subs.data = [{"id": "sub_trial_123"}] + trialing_subs.data = [ + stripe.Subscription.construct_from( + {"id": "sub_trial_123", "schedule": None}, "sk_test" + ) + ] trialing_subs.has_more = False def list_side_effect(*args, **kwargs): @@ -454,10 +474,18 @@ async def test_cancel_stripe_subscription_cancels_trialing(): async def test_cancel_stripe_subscription_cancels_active_and_trialing(): """Both active AND trialing subs present → both get scheduled for cancellation, no duplicates.""" active_subs = MagicMock() - active_subs.data = [{"id": "sub_active_1"}] + active_subs.data = [ + stripe.Subscription.construct_from( + {"id": "sub_active_1", "schedule": None}, "sk_test" + ) + ] active_subs.has_more = False trialing_subs = MagicMock() - trialing_subs.data = [{"id": "sub_trial_2"}] + trialing_subs.data = [ + stripe.Subscription.construct_from( + {"id": "sub_trial_2", "schedule": None}, "sk_test" + ) + ] trialing_subs.has_more = False def list_side_effect(*args, **kwargs): @@ -480,6 +508,62 @@ async def test_cancel_stripe_subscription_cancels_active_and_trialing(): assert modified_ids == {"sub_active_1", "sub_trial_2"} +@pytest.mark.asyncio +async def test_cancel_stripe_subscription_releases_attached_schedule_first(): + """Pre-existing Subscription Schedule must be released before cancel_at_period_end. + + Stripe rejects ``modify(cancel_at_period_end=True)`` with HTTP 400 when the + subscription has an attached schedule (e.g. user queued a BUSINESS→PRO + downgrade and now clicks "Downgrade to FREE"). Without the pre-release, + the API handler would surface a 502 to the user. + """ + mock_subscriptions = MagicMock() + mock_subscriptions.data = [ + stripe.Subscription.construct_from( + {"id": "sub_abc123", "schedule": "sub_sched_abc"}, "sk_test" + ) + ] + mock_subscriptions.has_more = False + + call_order: list[str] = [] + + async def record_release(schedule_id): + call_order.append(f"release:{schedule_id}") + + def record_modify(sub_id, **kwargs): + call_order.append(f"modify:{sub_id}:{kwargs}") + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=_make_user_with_stripe("cus_123"), + ), + patch( + "backend.data.credit.stripe.Subscription.list", + return_value=mock_subscriptions, + ), + patch( + "backend.data.credit.stripe.SubscriptionSchedule.release_async", + new_callable=AsyncMock, + side_effect=record_release, + ) as mock_release, + patch( + "backend.data.credit.stripe.Subscription.modify", + side_effect=record_modify, + ) as mock_modify, + ): + await cancel_stripe_subscription("user-1") + + mock_release.assert_awaited_once_with("sub_sched_abc") + mock_modify.assert_called_once_with("sub_abc123", cancel_at_period_end=True) + # Release must happen before modify, else Stripe returns 400. + assert call_order == [ + "release:sub_sched_abc", + "modify:sub_abc123:{'cancel_at_period_end': True}", + ] + + @pytest.mark.asyncio async def test_get_proration_credit_cents_no_stripe_customer_returns_zero(): """Admin-granted tier users without stripe_customer_id get 0 without creating a customer.""" @@ -878,7 +962,11 @@ async def test_cancel_stripe_subscription_raises_on_cancel_error(): import stripe as stripe_mod mock_subscriptions = MagicMock() - mock_subscriptions.data = [{"id": "sub_abc123"}] + mock_subscriptions.data = [ + stripe.Subscription.construct_from( + {"id": "sub_abc123", "schedule": None}, "sk_test" + ) + ] mock_subscriptions.has_more = False with ( @@ -1099,15 +1187,21 @@ async def test_handle_subscription_payment_failure_passes_invoice_id_as_transact @pytest.mark.asyncio async def test_modify_stripe_subscription_for_tier_modifies_existing_sub(): """modify_stripe_subscription_for_tier calls Subscription.modify and returns True.""" - mock_sub = { - "id": "sub_abc", - "items": {"data": [{"id": "si_abc"}]}, - } + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_abc", + "items": {"data": [{"id": "si_abc"}]}, + "schedule": None, + "cancel_at_period_end": False, + }, + "k", + ) mock_list = MagicMock() mock_list.data = [mock_sub] mock_user = MagicMock(spec=User) mock_user.stripe_customer_id = "cus_abc" + mock_user.subscription_tier = SubscriptionTier.FREE with ( patch( @@ -1121,12 +1215,18 @@ async def test_modify_stripe_subscription_for_tier_modifies_existing_sub(): return_value=mock_user, ), patch( - "backend.data.credit.stripe.Subscription.list", + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, return_value=mock_list, ), patch( - "backend.data.credit.stripe.Subscription.modify", + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, ) as mock_modify, + patch( + "backend.data.credit.set_subscription_tier", + new_callable=AsyncMock, + ) as mock_set_tier, ): result = await modify_stripe_subscription_for_tier( "user-1", SubscriptionTier.PRO @@ -1138,6 +1238,66 @@ async def test_modify_stripe_subscription_for_tier_modifies_existing_sub(): items=[{"id": "si_abc", "price": "price_pro_monthly"}], proration_behavior="create_prorations", ) + mock_set_tier.assert_awaited_once_with("user-1", SubscriptionTier.PRO) + + +@pytest.mark.asyncio +async def test_modify_stripe_subscription_for_tier_clears_cancel_at_period_end_on_upgrade(): + """Upgrading from a sub with cancel_at_period_end=True clears the flag so the + upgrade isn't silently cancelled at period end and the DB tier flips immediately.""" + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_upgrading", + "items": {"data": [{"id": "si_abc"}]}, + "schedule": None, + "cancel_at_period_end": True, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock(spec=User) + mock_user.stripe_customer_id = "cus_abc" + mock_user.subscription_tier = SubscriptionTier.PRO + + with ( + patch( + "backend.data.credit.get_subscription_price_id", + new_callable=AsyncMock, + return_value="price_biz_monthly", + ), + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + ) as mock_modify, + patch( + "backend.data.credit.set_subscription_tier", + new_callable=AsyncMock, + ) as mock_set_tier, + ): + result = await modify_stripe_subscription_for_tier( + "user-1", SubscriptionTier.BUSINESS + ) + + assert result is True + mock_modify.assert_called_once_with( + "sub_upgrading", + items=[{"id": "si_abc", "price": "price_biz_monthly"}], + proration_behavior="create_prorations", + cancel_at_period_end=False, + ) + mock_set_tier.assert_awaited_once_with("user-1", SubscriptionTier.BUSINESS) @pytest.mark.asyncio @@ -1178,6 +1338,7 @@ async def test_modify_stripe_subscription_for_tier_returns_false_when_no_sub(): mock_user = MagicMock(spec=User) mock_user.stripe_customer_id = "cus_abc" + mock_user.subscription_tier = SubscriptionTier.FREE with ( patch( @@ -1191,7 +1352,8 @@ async def test_modify_stripe_subscription_for_tier_returns_false_when_no_sub(): return_value=mock_user, ), patch( - "backend.data.credit.stripe.Subscription.list", + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, return_value=mock_list, ), ): @@ -1212,3 +1374,1089 @@ async def test_modify_stripe_subscription_for_tier_raises_on_missing_price_id(): ): with pytest.raises(ValueError, match="No Stripe price ID configured"): await modify_stripe_subscription_for_tier("user-1", SubscriptionTier.PRO) + + +def test_tier_order_helpers(): + assert is_tier_upgrade(SubscriptionTier.FREE, SubscriptionTier.PRO) is True + assert is_tier_upgrade(SubscriptionTier.PRO, SubscriptionTier.BUSINESS) is True + assert is_tier_upgrade(SubscriptionTier.BUSINESS, SubscriptionTier.PRO) is False + assert is_tier_downgrade(SubscriptionTier.BUSINESS, SubscriptionTier.PRO) is True + assert is_tier_downgrade(SubscriptionTier.PRO, SubscriptionTier.FREE) is True + assert is_tier_downgrade(SubscriptionTier.PRO, SubscriptionTier.BUSINESS) is False + + +@pytest.mark.asyncio +async def test_modify_stripe_subscription_for_tier_downgrade_creates_schedule(): + """Paid→paid downgrade (BUSINESS→PRO) creates a Subscription Schedule rather than proration.""" + import time as time_mod + + now = int(time_mod.time()) + period_end = now + 27 * 24 * 3600 + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_biz", + "items": {"data": [{"id": "si_biz", "price": {"id": "price_biz_monthly"}}]}, + "current_period_start": now - 3 * 24 * 3600, + "current_period_end": period_end, + "schedule": None, + "cancel_at_period_end": False, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock(spec=User) + mock_user.stripe_customer_id = "cus_abc" + mock_user.subscription_tier = SubscriptionTier.BUSINESS + + mock_schedule = stripe.SubscriptionSchedule.construct_from( + {"id": "sub_sched_1"}, "k" + ) + + with ( + patch( + "backend.data.credit.get_subscription_price_id", + new_callable=AsyncMock, + return_value="price_pro_monthly", + ), + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + ) as mock_modify, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.create_async", + new_callable=AsyncMock, + return_value=mock_schedule, + ) as mock_schedule_create, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.modify_async", + new_callable=AsyncMock, + ) as mock_schedule_modify, + ): + result = await modify_stripe_subscription_for_tier( + "user-1", SubscriptionTier.PRO + ) + + assert result is True + # Did NOT call Subscription.modify with proration (no immediate tier change). + mock_modify.assert_not_called() + mock_schedule_create.assert_called_once_with(from_subscription="sub_biz") + assert mock_schedule_modify.call_count == 1 + _, kwargs = mock_schedule_modify.call_args + phases = kwargs["phases"] + assert phases[0]["items"][0]["price"] == "price_biz_monthly" + assert phases[0]["end_date"] == period_end + assert phases[1]["items"][0]["price"] == "price_pro_monthly" + assert phases[0]["proration_behavior"] == "none" + assert phases[1]["proration_behavior"] == "none" + + +@pytest.mark.asyncio +async def test_modify_stripe_subscription_for_tier_upgrade_immediate_proration(): + """PRO→BUSINESS upgrade still uses Subscription.modify with proration (no schedule).""" + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_pro", + "items": {"data": [{"id": "si_pro", "price": {"id": "price_pro_monthly"}}]}, + "schedule": None, + "cancel_at_period_end": False, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock(spec=User) + mock_user.stripe_customer_id = "cus_abc" + mock_user.subscription_tier = SubscriptionTier.PRO + + with ( + patch( + "backend.data.credit.get_subscription_price_id", + new_callable=AsyncMock, + return_value="price_biz_monthly", + ), + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + ) as mock_modify, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.create_async", + new_callable=AsyncMock, + ) as mock_schedule_create, + patch( + "backend.data.credit.set_subscription_tier", + new_callable=AsyncMock, + ), + ): + result = await modify_stripe_subscription_for_tier( + "user-1", SubscriptionTier.BUSINESS + ) + + assert result is True + mock_modify.assert_called_once_with( + "sub_pro", + items=[{"id": "si_pro", "price": "price_biz_monthly"}], + proration_behavior="create_prorations", + ) + mock_schedule_create.assert_not_called() + + +@pytest.mark.asyncio +async def test_release_pending_subscription_schedule_releases_downgrade_schedule(): + """release_pending_subscription_schedule releases the Stripe schedule if one is attached.""" + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_biz", + "schedule": "sub_sched_1", + "cancel_at_period_end": False, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock() + mock_user.stripe_customer_id = "cus_abc" + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.SubscriptionSchedule.release_async", + new_callable=AsyncMock, + ) as mock_release, + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + ) as mock_modify, + ): + result = await release_pending_subscription_schedule("user-1") + + assert result is True + mock_release.assert_called_once_with("sub_sched_1") + mock_modify.assert_not_called() + + +@pytest.mark.asyncio +async def test_release_pending_subscription_schedule_clears_cancel_at_period_end(): + """release_pending_subscription_schedule reverts a pending paid→FREE cancel.""" + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_pro", + "schedule": None, + "cancel_at_period_end": True, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock() + mock_user.stripe_customer_id = "cus_abc" + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.SubscriptionSchedule.release_async", + new_callable=AsyncMock, + ) as mock_release, + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + ) as mock_modify, + ): + result = await release_pending_subscription_schedule("user-1") + + assert result is True + mock_modify.assert_called_once_with("sub_pro", cancel_at_period_end=False) + mock_release.assert_not_called() + + +@pytest.mark.asyncio +async def test_release_pending_subscription_schedule_no_pending_change_returns_false(): + """release_pending_subscription_schedule returns False when no schedule/cancel is set.""" + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_pro", + "schedule": None, + "cancel_at_period_end": False, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock() + mock_user.stripe_customer_id = "cus_abc" + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + ): + result = await release_pending_subscription_schedule("user-1") + + assert result is False + + +@pytest.mark.asyncio +async def test_release_pending_subscription_schedule_no_stripe_customer_returns_false(): + mock_user = MagicMock() + mock_user.stripe_customer_id = None + + with patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ): + result = await release_pending_subscription_schedule("user-1") + + assert result is False + + +@pytest.mark.asyncio +async def test_get_pending_subscription_change_cancel_at_period_end(): + """cancel_at_period_end=True maps to pending FREE at current_period_end.""" + import time as time_mod + + get_pending_subscription_change.cache_clear() # type: ignore[attr-defined] + + now = int(time_mod.time()) + period_end = now + 10 * 24 * 3600 + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_pro", + "current_period_end": period_end, + "cancel_at_period_end": True, + "schedule": None, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock() + mock_user.stripe_customer_id = "cus_abc" + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + if tier == SubscriptionTier.PRO: + return "price_pro_monthly" + if tier == SubscriptionTier.BUSINESS: + return "price_biz_monthly" + return None + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.get_subscription_price_id", + side_effect=mock_price_id, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + ): + result = await get_pending_subscription_change("user-1") + + assert result is not None + pending_tier, effective_at = result + assert pending_tier == SubscriptionTier.FREE + assert int(effective_at.timestamp()) == period_end + + +@pytest.mark.asyncio +async def test_get_pending_subscription_change_from_schedule(): + """A schedule whose next phase uses the PRO price maps to pending_tier=PRO.""" + import time as time_mod + + get_pending_subscription_change.cache_clear() # type: ignore[attr-defined] + + now = int(time_mod.time()) + period_end = now + 10 * 24 * 3600 + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_biz", + "current_period_end": period_end, + "cancel_at_period_end": False, + "schedule": "sub_sched_1", + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_schedule = stripe.SubscriptionSchedule.construct_from( + { + "id": "sub_sched_1", + "phases": [ + { + "start_date": now - 3 * 24 * 3600, + "end_date": period_end, + "items": [{"price": "price_biz_monthly"}], + }, + { + "start_date": period_end, + "items": [{"price": "price_pro_monthly"}], + }, + ], + }, + "k", + ) + + mock_user = MagicMock() + mock_user.stripe_customer_id = "cus_abc" + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + if tier == SubscriptionTier.PRO: + return "price_pro_monthly" + if tier == SubscriptionTier.BUSINESS: + return "price_biz_monthly" + return None + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.get_subscription_price_id", + side_effect=mock_price_id, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.SubscriptionSchedule.retrieve_async", + new_callable=AsyncMock, + return_value=mock_schedule, + ), + ): + result = await get_pending_subscription_change("user-1") + + assert result is not None + pending_tier, effective_at = result + assert pending_tier == SubscriptionTier.PRO + assert int(effective_at.timestamp()) == period_end + + +@pytest.mark.asyncio +async def test_get_pending_subscription_change_none_when_no_schedule_or_cancel(): + """Returns None when neither a schedule nor cancel_at_period_end is set.""" + import time as time_mod + + get_pending_subscription_change.cache_clear() # type: ignore[attr-defined] + + now = int(time_mod.time()) + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_pro", + "current_period_end": now + 10 * 24 * 3600, + "cancel_at_period_end": False, + "schedule": None, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock() + mock_user.stripe_customer_id = "cus_abc" + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + return { + SubscriptionTier.PRO: "price_pro", + SubscriptionTier.BUSINESS: "price_biz", + }.get(tier) + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.get_subscription_price_id", + side_effect=mock_price_id, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + ): + result = await get_pending_subscription_change("user-1") + + assert result is None + + +@pytest.mark.asyncio +async def test_sync_subscription_schedule_from_stripe_retrieves_and_delegates(): + """subscription_schedule.released triggers a sync via the active subscription object.""" + stripe_schedule = {"id": "sub_sched_1", "subscription": "sub_pro"} + retrieved_sub = stripe.Subscription.construct_from( + { + "id": "sub_pro", + "customer": "cus_abc", + "status": "active", + "items": {"data": [{"price": {"id": "price_pro_monthly"}}]}, + }, + "k", + ) + + with ( + patch( + "backend.data.credit.stripe.Subscription.retrieve_async", + new_callable=AsyncMock, + return_value=retrieved_sub, + ) as mock_retrieve, + patch( + "backend.data.credit.sync_subscription_from_stripe", + new_callable=AsyncMock, + ) as mock_sync, + ): + await sync_subscription_schedule_from_stripe(stripe_schedule) + + mock_retrieve.assert_called_once_with("sub_pro") + mock_sync.assert_awaited_once() + forwarded = mock_sync.call_args.args[0] + assert forwarded["id"] == "sub_pro" + assert forwarded["customer"] == "cus_abc" + + +@pytest.mark.asyncio +async def test_sync_subscription_schedule_from_stripe_uses_released_subscription_fallback(): + """subscription_schedule.released events clear `subscription` and set + `released_subscription`; the sync handler must fall back to that id.""" + stripe_schedule = { + "id": "sub_sched_1", + "subscription": None, + "released_subscription": "sub_pro_released", + } + retrieved_sub = stripe.Subscription.construct_from( + { + "id": "sub_pro_released", + "customer": "cus_abc", + "status": "active", + "items": {"data": [{"price": {"id": "price_pro_monthly"}}]}, + }, + "k", + ) + + with ( + patch( + "backend.data.credit.stripe.Subscription.retrieve_async", + new_callable=AsyncMock, + return_value=retrieved_sub, + ) as mock_retrieve, + patch( + "backend.data.credit.sync_subscription_from_stripe", + new_callable=AsyncMock, + ) as mock_sync, + ): + await sync_subscription_schedule_from_stripe(stripe_schedule) + + mock_retrieve.assert_called_once_with("sub_pro_released") + mock_sync.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_sync_subscription_schedule_from_stripe_missing_sub_id_returns(): + """A schedule event with no 'subscription' field is logged and ignored.""" + with patch( + "backend.data.credit.stripe.Subscription.retrieve_async", + new_callable=AsyncMock, + ) as mock_retrieve: + await sync_subscription_schedule_from_stripe({"id": "sub_sched_1"}) + mock_retrieve.assert_not_called() + + +@pytest.mark.asyncio +async def test_sync_subscription_from_stripe_phase_transition_updates_tier(): + """When a schedule advances phases, Stripe fires customer.subscription.updated with + the new price — the existing sync handler must update the DB tier accordingly.""" + mock_user = _make_user(tier=SubscriptionTier.BUSINESS) + stripe_sub = { + "id": "sub_pro", + "customer": "cus_abc", + "status": "active", + # Phase advanced: price is now PRO (was BUSINESS before). + "items": {"data": [{"price": {"id": "price_pro_monthly"}}]}, + } + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + if tier == SubscriptionTier.PRO: + return "price_pro_monthly" + if tier == SubscriptionTier.BUSINESS: + return "price_biz_monthly" + return None + + empty_list = MagicMock() + empty_list.data = [] + empty_list.has_more = False + + with ( + patch( + "backend.data.credit.User.prisma", + return_value=MagicMock(find_first=AsyncMock(return_value=mock_user)), + ), + patch( + "backend.data.credit.get_subscription_price_id", + side_effect=mock_price_id, + ), + patch( + "backend.data.credit.stripe.Subscription.list", + return_value=empty_list, + ), + patch( + "backend.data.credit.set_subscription_tier", new_callable=AsyncMock + ) as mock_set, + ): + await sync_subscription_from_stripe(stripe_sub) + mock_set.assert_awaited_once_with("user-1", SubscriptionTier.PRO) + + +@pytest.mark.asyncio +async def test_release_schedule_idempotent_on_terminal_state(): + """SubscriptionSchedule.release raising InvalidRequestError on a terminal-state + schedule is treated as success; we still continue to the cancel_at_period_end clear. + """ + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_biz", + "schedule": "sub_sched_terminal", + "cancel_at_period_end": True, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock() + mock_user.stripe_customer_id = "cus_abc" + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.SubscriptionSchedule.release_async", + new_callable=AsyncMock, + side_effect=stripe.InvalidRequestError( + "Schedule has already been released", + param="schedule", + ), + ) as mock_release, + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + ) as mock_modify, + ): + result = await release_pending_subscription_schedule("user-1") + + # Terminal-state release is treated as idempotent success; modify still runs. + assert result is True + mock_release.assert_called_once_with("sub_sched_terminal") + mock_modify.assert_called_once_with("sub_biz", cancel_at_period_end=False) + + +@pytest.mark.asyncio +async def test_schedule_downgrade_releases_existing_schedule(): + """_schedule_downgrade_at_period_end releases any pre-existing schedule first.""" + import time as time_mod + + now = int(time_mod.time()) + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_biz", + "schedule": "sub_sched_old", + "cancel_at_period_end": False, + "items": {"data": [{"id": "si_biz", "price": {"id": "price_biz_monthly"}}]}, + "current_period_start": now - 3 * 24 * 3600, + "current_period_end": now + 27 * 24 * 3600, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock(spec=User) + mock_user.stripe_customer_id = "cus_abc" + mock_user.subscription_tier = SubscriptionTier.BUSINESS + + mock_new_schedule = stripe.SubscriptionSchedule.construct_from( + {"id": "sub_sched_new"}, "k" + ) + + with ( + patch( + "backend.data.credit.get_subscription_price_id", + new_callable=AsyncMock, + return_value="price_pro_monthly", + ), + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + ) as mock_modify, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.release_async", + new_callable=AsyncMock, + ) as mock_release, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.create_async", + new_callable=AsyncMock, + return_value=mock_new_schedule, + ) as mock_create, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.modify_async", + new_callable=AsyncMock, + ), + ): + result = await modify_stripe_subscription_for_tier( + "user-1", SubscriptionTier.PRO + ) + + assert result is True + # Existing schedule released before creating the new one. + mock_release.assert_called_once_with("sub_sched_old") + mock_create.assert_called_once_with(from_subscription="sub_biz") + # cancel_at_period_end was False, so Subscription.modify should not be called. + mock_modify.assert_not_called() + + +@pytest.mark.asyncio +async def test_schedule_downgrade_clears_cancel_at_period_end(): + """_schedule_downgrade_at_period_end clears cancel_at_period_end before scheduling.""" + import time as time_mod + + now = int(time_mod.time()) + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_biz", + "schedule": None, + "cancel_at_period_end": True, + "items": {"data": [{"id": "si_biz", "price": {"id": "price_biz_monthly"}}]}, + "current_period_start": now - 3 * 24 * 3600, + "current_period_end": now + 27 * 24 * 3600, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock(spec=User) + mock_user.stripe_customer_id = "cus_abc" + mock_user.subscription_tier = SubscriptionTier.BUSINESS + + mock_new_schedule = stripe.SubscriptionSchedule.construct_from( + {"id": "sub_sched_new"}, "k" + ) + + with ( + patch( + "backend.data.credit.get_subscription_price_id", + new_callable=AsyncMock, + return_value="price_pro_monthly", + ), + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + ) as mock_modify, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.create_async", + new_callable=AsyncMock, + return_value=mock_new_schedule, + ) as mock_create, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.modify_async", + new_callable=AsyncMock, + ), + ): + result = await modify_stripe_subscription_for_tier( + "user-1", SubscriptionTier.PRO + ) + + assert result is True + # cancel_at_period_end cleared before new schedule is created. + mock_modify.assert_called_once_with("sub_biz", cancel_at_period_end=False) + mock_create.assert_called_once_with(from_subscription="sub_biz") + + +@pytest.mark.asyncio +async def test_schedule_downgrade_rolls_back_orphan_on_modify_failure(): + """If SubscriptionSchedule.modify fails after a successful create, the + orphaned schedule must be released so it doesn't stay attached and block + future changes. The original StripeError re-raises to the caller. + """ + import time as time_mod + + now = int(time_mod.time()) + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_biz", + "schedule": None, + "cancel_at_period_end": False, + "items": {"data": [{"id": "si_biz", "price": {"id": "price_biz_monthly"}}]}, + "current_period_start": now - 3 * 24 * 3600, + "current_period_end": now + 27 * 24 * 3600, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock(spec=User) + mock_user.stripe_customer_id = "cus_abc" + mock_user.subscription_tier = SubscriptionTier.BUSINESS + + mock_new_schedule = stripe.SubscriptionSchedule.construct_from( + {"id": "sub_sched_new"}, "k" + ) + + with ( + patch( + "backend.data.credit.get_subscription_price_id", + new_callable=AsyncMock, + return_value="price_pro_monthly", + ), + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + ), + patch( + "backend.data.credit.stripe.SubscriptionSchedule.create_async", + new_callable=AsyncMock, + return_value=mock_new_schedule, + ) as mock_create, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.modify_async", + new_callable=AsyncMock, + side_effect=stripe.APIConnectionError("network down"), + ) as mock_schedule_modify, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.release_async", + new_callable=AsyncMock, + ) as mock_release, + ): + with pytest.raises(stripe.APIConnectionError): + await modify_stripe_subscription_for_tier("user-1", SubscriptionTier.PRO) + + mock_create.assert_called_once_with(from_subscription="sub_biz") + mock_schedule_modify.assert_called_once() + # Rollback must release the freshly-created (and now orphaned) schedule + # id, not the pre-existing one (there was none here). + mock_release.assert_called_once_with("sub_sched_new") + + +@pytest.mark.asyncio +async def test_release_ignoring_terminal_reraises_non_terminal_error(): + """_release_schedule_ignoring_terminal only swallows terminal-state errors. + Typos / wrong ids / 404s surface so bugs aren't silently masked. + """ + from backend.data.credit import _release_schedule_ignoring_terminal + + with patch( + "backend.data.credit.stripe.SubscriptionSchedule.release_async", + new_callable=AsyncMock, + side_effect=stripe.InvalidRequestError( + "No such subscription_schedule: 'sub_sched_typo'", + param="schedule", + ), + ): + with pytest.raises(stripe.InvalidRequestError): + await _release_schedule_ignoring_terminal("sub_sched_typo", "test_context") + + +@pytest.mark.asyncio +async def test_release_ignoring_terminal_swallows_terminal_error(): + """Terminal-state messages are treated as idempotent success and return False.""" + from backend.data.credit import _release_schedule_ignoring_terminal + + with patch( + "backend.data.credit.stripe.SubscriptionSchedule.release_async", + new_callable=AsyncMock, + side_effect=stripe.InvalidRequestError( + "Schedule has already been released", + param="schedule", + ), + ): + result = await _release_schedule_ignoring_terminal( + "sub_sched_done", "test_context" + ) + + assert result is False + + +@pytest.mark.asyncio +async def test_upgrade_releases_pending_schedule(): + """modify_stripe_subscription_for_tier upgrade path releases attached schedule first.""" + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_pro", + "schedule": "sub_sched_pending_downgrade", + "cancel_at_period_end": False, + "items": {"data": [{"id": "si_pro", "price": {"id": "price_pro_monthly"}}]}, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + mock_user = MagicMock(spec=User) + mock_user.stripe_customer_id = "cus_abc" + mock_user.subscription_tier = SubscriptionTier.PRO + + with ( + patch( + "backend.data.credit.get_subscription_price_id", + new_callable=AsyncMock, + return_value="price_biz_monthly", + ), + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + ) as mock_modify, + patch( + "backend.data.credit.stripe.SubscriptionSchedule.release_async", + new_callable=AsyncMock, + ) as mock_release, + patch( + "backend.data.credit.set_subscription_tier", + new_callable=AsyncMock, + ), + ): + result = await modify_stripe_subscription_for_tier( + "user-1", SubscriptionTier.BUSINESS + ) + + assert result is True + # Pending schedule released before the upgrade modify call. + mock_release.assert_called_once_with("sub_sched_pending_downgrade") + mock_modify.assert_called_once_with( + "sub_pro", + items=[{"id": "si_pro", "price": "price_biz_monthly"}], + proration_behavior="create_prorations", + ) + + +@pytest.mark.asyncio +async def test_next_phase_tier_and_start_logs_unknown_price(caplog): + """_next_phase_tier_and_start emits a warning when the next-phase price is unmapped.""" + import logging + import time as time_mod + + from backend.data.credit import _next_phase_tier_and_start + + now = int(time_mod.time()) + schedule = stripe.SubscriptionSchedule.construct_from( + { + "id": "sub_sched_unknown", + "phases": [ + { + "start_date": now - 3 * 24 * 3600, + "end_date": now + 27 * 24 * 3600, + "items": [{"price": "price_current"}], + }, + { + "start_date": now + 27 * 24 * 3600, + "items": [{"price": "price_unknown"}], + }, + ], + }, + "k", + ) + price_to_tier = {"price_pro_monthly": SubscriptionTier.PRO} + + with caplog.at_level(logging.WARNING, logger="backend.data.credit"): + result = _next_phase_tier_and_start(schedule, price_to_tier) + + assert result is None + assert any( + "next_phase_tier_and_start: unknown price price_unknown" in record.message + and "sub_sched_unknown" in record.message + for record in caplog.records + ) + + +@pytest.mark.asyncio +async def test_get_pending_subscription_change_raises_when_price_lookups_fail(): + """When both LD price lookups return None, raise PendingChangeUnknown so the + @cached wrapper doesn't store None and hide pending changes for 30s.""" + from backend.data.credit import PendingChangeUnknown + + get_pending_subscription_change.cache_clear() # type: ignore[attr-defined] + + mock_user = MagicMock() + mock_user.stripe_customer_id = "cus_abc" + + async def mock_price_id(tier: SubscriptionTier) -> str | None: + return None + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.get_subscription_price_id", + side_effect=mock_price_id, + ), + pytest.raises(PendingChangeUnknown), + ): + await get_pending_subscription_change("user-price-fail") + + +@pytest.mark.asyncio +async def test_release_pending_subscription_schedule_invalidates_cache_on_partial_failure(): + """If schedule.release succeeds but cancel_at_period_end clear fails, the + cache must still be invalidated — otherwise the UI shows a stale pending + banner for up to 30s even though the schedule was actually released.""" + get_pending_subscription_change.cache_clear() # type: ignore[attr-defined] + + mock_user = MagicMock() + mock_user.stripe_customer_id = "cus_abc" + + import time as time_mod + + mock_sub = stripe.Subscription.construct_from( + { + "id": "sub_mixed", + "schedule": "sub_sched_to_release", + "cancel_at_period_end": True, + "current_period_end": int(time_mod.time()) + 10 * 24 * 3600, + }, + "k", + ) + mock_list = MagicMock() + mock_list.data = [mock_sub] + + with ( + patch( + "backend.data.credit.get_user_by_id", + new_callable=AsyncMock, + return_value=mock_user, + ), + patch( + "backend.data.credit.stripe.Subscription.list_async", + new_callable=AsyncMock, + return_value=mock_list, + ), + patch( + "backend.data.credit.stripe.SubscriptionSchedule.release_async", + new_callable=AsyncMock, + return_value=MagicMock(), + ), + patch( + "backend.data.credit.stripe.Subscription.modify_async", + new_callable=AsyncMock, + side_effect=stripe.APIConnectionError("transient Stripe error"), + ), + patch.object( + get_pending_subscription_change, "cache_delete" + ) as mock_cache_delete, + ): + with pytest.raises(stripe.APIConnectionError): + await release_pending_subscription_schedule("user-partial") + + mock_cache_delete.assert_called_once_with("user-partial") diff --git a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/SubscriptionTierSection.tsx b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/SubscriptionTierSection.tsx index 58a4b9d58b..d8aab67b22 100644 --- a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/SubscriptionTierSection.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/SubscriptionTierSection.tsx @@ -4,42 +4,14 @@ import { Button } from "@/components/ui/button"; import { Dialog } from "@/components/molecules/Dialog/Dialog"; import { Skeleton } from "@/components/atoms/Skeleton/Skeleton"; import { useSubscriptionTierSection } from "./useSubscriptionTierSection"; - -type TierInfo = { - key: string; - label: string; - multiplier: string; - description: string; -}; - -const TIERS: TierInfo[] = [ - { - key: "FREE", - label: "Free", - multiplier: "1x", - description: "Base AutoPilot capacity with standard rate limits", - }, - { - key: "PRO", - label: "Pro", - multiplier: "5x", - description: "5x AutoPilot capacity — run 5× more tasks per day/week", - }, - { - key: "BUSINESS", - label: "Business", - multiplier: "20x", - description: "20x AutoPilot capacity — ideal for teams and heavy workloads", - }, -]; - -const TIER_ORDER = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"]; - -function formatCost(cents: number, tierKey: string): string { - if (tierKey === "FREE") return "Free"; - if (cents === 0) return "Pricing available soon"; - return `$${(cents / 100).toFixed(2)}/mo`; -} +import { PendingChangeBanner } from "./components/PendingChangeBanner/PendingChangeBanner"; +import { + TIERS, + TIER_ORDER, + formatCost, + formatPendingDate, + getTierLabel, +} from "./helpers"; export function SubscriptionTierSection() { const { @@ -55,10 +27,14 @@ export function SubscriptionTierSection() { isPaymentEnabled, changeTier, handleTierChange, + cancelPendingChange, } = useSubscriptionTierSection(); const [confirmDowngradeTo, setConfirmDowngradeTo] = useState( null, ); + const [confirmReplacePendingTo, setConfirmReplacePendingTo] = useState< + string | null + >(null); if (isLoading) { return ( @@ -115,6 +91,34 @@ export function SubscriptionTierSection() { await changeTier(tier); } + async function confirmReplacePending() { + if (!confirmReplacePendingTo) return; + const tier = confirmReplacePendingTo; + setConfirmReplacePendingTo(null); + handleTierChange(tier, currentTier, setConfirmDowngradeTo); + } + + const pendingTierFromSubscription = subscription.pending_tier ?? null; + const hasPendingChange = + pendingTierFromSubscription !== null && + pendingTierFromSubscription !== currentTier; + + function onTierButtonClick(targetTierKey: string) { + // If a pending change is queued and the user clicks a DIFFERENT non-current, + // non-pending tier, surface a confirmation so they don't silently overwrite + // their own scheduled change. The on-card button for the pending tier itself + // is already disabled; the primary cancel path is the banner. + if ( + hasPendingChange && + targetTierKey !== pendingTierFromSubscription && + targetTierKey !== currentTier + ) { + setConfirmReplacePendingTo(targetTierKey); + return; + } + handleTierChange(targetTierKey, currentTier, setConfirmDowngradeTo); + } + return (

Subscription Plan

@@ -128,6 +132,16 @@ export function SubscriptionTierSection() {

)} + {hasPendingChange && pendingTierFromSubscription ? ( + void cancelPendingChange()} + isBusy={isPending} + /> + ) : null} +
{TIERS.map((tier) => { const isCurrent = currentTier === tier.key; @@ -137,6 +151,8 @@ export function SubscriptionTierSection() { const isUpgrade = targetIdx > currentIdx; const isDowngrade = targetIdx < currentIdx; const isThisPending = pendingTier === tier.key; + const isScheduledTier = + hasPendingChange && pendingTierFromSubscription === tier.key; return (
- handleTierChange( - tier.key, - currentTier, - setConfirmDowngradeTo, - ) - } + disabled={isPending || isScheduledTier} + onClick={() => onTierButtonClick(tier.key)} > {isThisPending ? "Updating..." - : isUpgrade - ? `Upgrade to ${tier.label}` - : isDowngrade - ? `Downgrade to ${tier.label}` - : `Switch to ${tier.label}`} + : isScheduledTier + ? "Scheduled" + : isUpgrade + ? `Upgrade to ${tier.label}` + : isDowngrade + ? `Downgrade to ${tier.label}` + : `Switch to ${tier.label}`} )}
@@ -196,9 +208,9 @@ export function SubscriptionTierSection() { {currentTier !== "FREE" && isPaymentEnabled && (

- Your subscription is managed through Stripe. Upgrades and paid-tier - changes take effect immediately; downgrades to Free are scheduled for - the end of the current billing period. + Your subscription is managed through Stripe. Upgrades take effect + immediately. Downgrades take effect at the end of your current billing + period.

)} @@ -215,7 +227,7 @@ export function SubscriptionTierSection() {

{confirmDowngradeTo === "FREE" ? "Downgrading to Free will schedule your subscription to cancel at the end of your current billing period. You keep your current plan until then." - : `Switching to ${TIERS.find((t) => t.key === confirmDowngradeTo)?.label ?? confirmDowngradeTo} will take effect immediately.`}{" "} + : `Switching to ${TIERS.find((t) => t.key === confirmDowngradeTo)?.label ?? confirmDowngradeTo} will take effect at the end of your current billing period. You keep your current plan until then.`}{" "} Are you sure?

@@ -235,6 +247,42 @@ export function SubscriptionTierSection() { + { + if (!open) setConfirmReplacePendingTo(null); + }, + }} + > + +

+ You have a pending change to{" "} + {getTierLabel(pendingTierFromSubscription ?? "")} + {subscription.pending_tier_effective_at + ? ` scheduled for ${formatPendingDate(subscription.pending_tier_effective_at)}` + : ""} + . Switching to {getTierLabel(confirmReplacePendingTo ?? "")} will + replace it. Continue? +

+ + + + +
+
+ ; prorationCreditCents?: number; + pendingTier?: string | null; + pendingTierEffectiveAt?: Date | string | null; } = {}) { return { tier, monthly_cost: monthlyCost, tier_costs: tierCosts, proration_credit_cents: prorationCreditCents, + pending_tier: pendingTier, + pending_tier_effective_at: pendingTierEffectiveAt, }; } @@ -92,6 +98,7 @@ function setupMocks({ mutateFn = vi.fn().mockResolvedValue({ status: 200, data: { url: "" } }), isPending = false, variables = undefined as { data?: { tier?: string } } | undefined, + refetchFn = vi.fn(), } = {}) { // The hook uses select: (data) => (data.status === 200 ? data.data : null) // so the data value returned by the hook is already the transformed subscription object. @@ -100,13 +107,14 @@ function setupMocks({ data: subscription, isLoading, error: queryError, - refetch: vi.fn(), + refetch: refetchFn, }); mockUseUpdateSubscriptionTier.mockReturnValue({ mutateAsync: mutateFn, isPending, variables, }); + return { refetchFn, mutateFn }; } afterEach(() => { @@ -355,4 +363,229 @@ describe("SubscriptionTierSection", () => { // No toast should fire — the user simply abandoned checkout expect(mockToast).not.toHaveBeenCalled(); }); + + it("renders pending-change banner when pending_tier is set", () => { + setupMocks({ + subscription: makeSubscription({ + tier: "BUSINESS", + pendingTier: "PRO", + pendingTierEffectiveAt: new Date("2026-11-15T00:00:00Z"), + }), + }); + render(); + expect(screen.getByText(/scheduled to downgrade to/i)).toBeDefined(); + // Banner "Keep Business" button — the only Keep button, since the on-card + // duplicate was removed in favour of the banner. + expect( + screen.getAllByRole("button", { name: /keep business/i }), + ).toHaveLength(1); + }); + + it("does not render pending-change banner when pending_tier is null", () => { + setupMocks({ + subscription: makeSubscription({ tier: "BUSINESS", pendingTier: null }), + }); + render(); + expect(screen.queryByText(/scheduled to downgrade/i)).toBeNull(); + expect(screen.queryByRole("button", { name: /keep business/i })).toBeNull(); + }); + + it("clicking Keep [CurrentTier] in banner submits a same-tier update and refetches", async () => { + // The cancel-pending route was collapsed into POST /credits/subscription as + // a same-tier request. Clicking "Keep BUSINESS" calls useUpdateSubscriptionTier + // with tier === current tier so the backend releases any pending schedule. + const mutateFn = vi + .fn() + .mockResolvedValue({ status: 200, data: { url: "", tier: "BUSINESS" } }); + const refetchFn = vi.fn(); + setupMocks({ + subscription: makeSubscription({ + tier: "BUSINESS", + pendingTier: "PRO", + pendingTierEffectiveAt: new Date("2026-11-15T00:00:00Z"), + }), + mutateFn, + refetchFn, + }); + render(); + + fireEvent.click(screen.getByRole("button", { name: /keep business/i })); + + await waitFor(() => { + expect(mutateFn).toHaveBeenCalledWith( + expect.objectContaining({ + data: expect.objectContaining({ tier: "BUSINESS" }), + }), + ); + expect(refetchFn).toHaveBeenCalled(); + }); + expect(mockToast).toHaveBeenCalledWith( + expect.objectContaining({ + title: "Pending subscription change cancelled.", + }), + ); + }); + + it("uses end-of-period copy for paid→paid downgrade confirmation", () => { + setupMocks({ subscription: makeSubscription({ tier: "BUSINESS" }) }); + render(); + + fireEvent.click(screen.getByRole("button", { name: /downgrade to pro/i })); + + const dialog = screen.getByRole("dialog"); + expect(dialog.textContent).toMatch( + /switching to pro will take effect at the end of your current billing period/i, + ); + expect(dialog.textContent).toMatch( + /you keep your current plan until then/i, + ); + expect(dialog.textContent).not.toMatch(/take effect immediately/i); + }); + + it("shows destructive toast, tierError and still refetches when cancel-pending fails", async () => { + // The catch branch inside cancelPendingChange is load-bearing: it surfaces + // the error to the user AND re-issues a refetch so the UI reconciles if + // the server actually succeeded (webhook delivered after our client-side + // error). + const mutateFn = vi + .fn() + .mockRejectedValue(new Error("Stripe webhook failed")); + const refetchFn = vi.fn(); + setupMocks({ + subscription: makeSubscription({ + tier: "BUSINESS", + pendingTier: "PRO", + pendingTierEffectiveAt: new Date("2026-11-15T00:00:00Z"), + }), + mutateFn, + refetchFn, + }); + render(); + + const keepButtons = screen.getAllByRole("button", { + name: /keep business/i, + }); + fireEvent.click(keepButtons[0]); + + await waitFor(() => { + expect(screen.getByRole("alert")).toBeDefined(); + expect(screen.getByText(/stripe webhook failed/i)).toBeDefined(); + }); + expect(mockToast).toHaveBeenCalledWith( + expect.objectContaining({ + title: "Failed to cancel pending change", + variant: "destructive", + }), + ); + expect(refetchFn).toHaveBeenCalled(); + }); + + it("disables the tier button that matches the pending tier so users can't overwrite their own scheduled change by mis-click", () => { + // User is on BUSINESS and has a pending downgrade to PRO. The "Downgrade + // to Pro" button must be disabled + labelled "Scheduled" so the primary + // cancel path stays the banner. Other tier buttons (FREE here) remain + // clickable — the user can still overwrite their pending change by + // picking a different target; backend handles that. + setupMocks({ + subscription: makeSubscription({ + tier: "BUSINESS", + pendingTier: "PRO", + pendingTierEffectiveAt: new Date("2026-11-15T00:00:00Z"), + }), + }); + render(); + + const scheduledBtn = screen.getByRole("button", { name: /scheduled/i }); + expect(scheduledBtn).toBeDefined(); + expect((scheduledBtn as HTMLButtonElement).disabled).toBe(true); + + // The non-pending tier (FREE) button is still clickable. + const freeBtn = screen.getByRole("button", { name: /downgrade to free/i }); + expect((freeBtn as HTMLButtonElement).disabled).toBe(false); + }); + + it("shows replace-pending dialog when clicking a non-pending tier while a pending change exists, and fires the mutation after confirm", async () => { + // User is on BUSINESS with a pending downgrade to PRO. Clicking FREE (a + // tier that is neither current nor the pending target) must NOT silently + // overwrite the pending schedule — it must open a confirmation dialog. + // Only after the user explicitly confirms should changeTier (→ its own + // downgrade confirm for paid→FREE) fire. + const mutateFn = vi + .fn() + .mockResolvedValue({ status: 200, data: { url: "" } }); + setupMocks({ + subscription: makeSubscription({ + tier: "BUSINESS", + pendingTier: "PRO", + pendingTierEffectiveAt: new Date("2026-11-15T00:00:00Z"), + }), + mutateFn, + }); + render(); + + // Clicking FREE while PRO is pending surfaces the replace-pending dialog + // before anything mutates. + fireEvent.click(screen.getByRole("button", { name: /downgrade to free/i })); + expect(screen.getByRole("dialog")).toBeDefined(); + expect(screen.getByText(/replace pending change/i)).toBeDefined(); + expect(mutateFn).not.toHaveBeenCalled(); + + // Confirm the replace: the replace-pending dialog closes and the + // downgrade-to-FREE dialog takes over (because FREE is a downgrade). + fireEvent.click( + screen.getByRole("button", { name: /replace pending change/i }), + ); + + // Now the "Confirm Downgrade" dialog should be open — confirm it to fire + // the mutation. + fireEvent.click(screen.getByRole("button", { name: /confirm downgrade/i })); + + await waitFor(() => { + expect(mutateFn).toHaveBeenCalledWith( + expect.objectContaining({ + data: expect.objectContaining({ tier: "FREE" }), + }), + ); + }); + }); + + it("dismisses replace-pending dialog on Cancel without mutating", () => { + const mutateFn = vi + .fn() + .mockResolvedValue({ status: 200, data: { url: "" } }); + setupMocks({ + subscription: makeSubscription({ + tier: "BUSINESS", + pendingTier: "PRO", + pendingTierEffectiveAt: new Date("2026-11-15T00:00:00Z"), + }), + mutateFn, + }); + render(); + + fireEvent.click(screen.getByRole("button", { name: /downgrade to free/i })); + expect(screen.getByRole("dialog")).toBeDefined(); + + fireEvent.click(screen.getByRole("button", { name: /^cancel$/i })); + expect(screen.queryByRole("dialog")).toBeNull(); + expect(mutateFn).not.toHaveBeenCalled(); + }); + + it("renders FREE cancellation copy in banner when pending_tier is FREE", () => { + setupMocks({ + subscription: makeSubscription({ + tier: "BUSINESS", + pendingTier: "FREE", + pendingTierEffectiveAt: new Date("2026-05-15T00:00:00Z"), + }), + }); + render(); + // Cancellation copy — distinct from the generic downgrade phrasing. + expect( + screen.getByText(/scheduled to cancel your subscription on/i), + ).toBeDefined(); + expect(screen.getByText(/May 15, 2026/)).toBeDefined(); + // Must NOT render the "downgrade to" phrasing on FREE cancellation. + expect(screen.queryByText(/scheduled to downgrade to/i)).toBeNull(); + }); }); diff --git a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/components/PendingChangeBanner/PendingChangeBanner.tsx b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/components/PendingChangeBanner/PendingChangeBanner.tsx new file mode 100644 index 0000000000..0088ad7666 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/components/PendingChangeBanner/PendingChangeBanner.tsx @@ -0,0 +1,60 @@ +import { Button } from "@/components/ui/button"; +import { formatPendingDate, getTierLabel } from "../../helpers"; + +interface Props { + currentTier: string; + pendingTier: string; + pendingEffectiveAt: Date | string | null | undefined; + onKeepCurrent: () => void; + isBusy: boolean; +} + +export function PendingChangeBanner({ + currentTier, + pendingTier, + pendingEffectiveAt, + onKeepCurrent, + isBusy, +}: Props) { + // Backend invariant: pending_tier_effective_at is always populated when + // pending_tier is set. Bail early if the date is missing so the sentence + // always reads with a date instead of a null-fallback branch. + if (!pendingEffectiveAt) return null; + + const pendingLabel = getTierLabel(pendingTier); + const currentLabel = getTierLabel(currentTier); + const dateText = formatPendingDate(pendingEffectiveAt); + + const isCancellation = pendingTier === "FREE"; + + return ( +
+

+ {isCancellation ? ( + <> + Scheduled to cancel your subscription on{" "} + {dateText}. + + ) : ( + <> + Scheduled to downgrade to{" "} + {pendingLabel} on{" "} + {dateText}. + + )} +

+ +
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/helpers.ts new file mode 100644 index 0000000000..fde4674a8b --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/helpers.ts @@ -0,0 +1,54 @@ +export interface TierInfo { + key: string; + label: string; + multiplier: string; + description: string; +} + +export const TIERS: TierInfo[] = [ + { + key: "FREE", + label: "Free", + multiplier: "1x", + description: "Base AutoPilot capacity with standard rate limits", + }, + { + key: "PRO", + label: "Pro", + multiplier: "5x", + description: "5x AutoPilot capacity — run 5× more tasks per day/week", + }, + { + key: "BUSINESS", + label: "Business", + multiplier: "20x", + description: "20x AutoPilot capacity — ideal for teams and heavy workloads", + }, +]; + +export const TIER_ORDER = ["FREE", "PRO", "BUSINESS", "ENTERPRISE"]; + +export function formatCost(cents: number, tierKey: string): string { + if (tierKey === "FREE") return "Free"; + if (cents === 0) return "Pricing available soon"; + return `$${(cents / 100).toFixed(2)}/mo`; +} + +export function getTierLabel(tierKey: string): string { + return ( + TIERS.find((t) => t.key === tierKey)?.label ?? + tierKey.charAt(0) + tierKey.slice(1).toLowerCase() + ); +} + +export function formatPendingDate(value: Date | string): string { + const date = value instanceof Date ? value : new Date(value); + // Pin to en-US so SSR and CSR produce the same string — passing `undefined` + // picks up the server's locale during prerender and the browser's locale on + // hydration, which triggers a React hydration mismatch warning. + return date.toLocaleDateString("en-US", { + year: "numeric", + month: "short", + day: "numeric", + }); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/useSubscriptionTierSection.ts b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/useSubscriptionTierSection.ts index 862551c7e3..d51a2a6051 100644 --- a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/useSubscriptionTierSection.ts +++ b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/components/SubscriptionTierSection/useSubscriptionTierSection.ts @@ -117,6 +117,47 @@ export function useSubscriptionTierSection() { await changeTier(tier); } + async function cancelPendingChange() { + if (!subscription) return; + setTierError(null); + try { + // "Stay on my current tier" is a same-tier POST: the backend collapses + // cancel-pending into update-tier and releases any pending schedule. + // success_url/cancel_url are unused in this branch (no Stripe Checkout + // is created) but are sent to satisfy the request schema. + await doUpdateTier({ + data: { + tier: subscription.tier as SubscriptionTierRequestTier, + success_url: `${window.location.origin}${window.location.pathname}`, + cancel_url: `${window.location.origin}${window.location.pathname}`, + }, + }); + await refetch(); + toast({ + title: "Pending subscription change cancelled.", + }); + } catch (e: unknown) { + const msg = + e instanceof Error + ? e.message + : "Failed to cancel pending subscription change"; + setTierError(msg); + toast({ + title: "Failed to cancel pending change", + description: msg, + variant: "destructive", + }); + // Refetch on error so the UI reconciles if the server actually + // succeeded (e.g. webhook delivered after our client-side error). + // Swallow refetch errors — we already have the primary error for display. + try { + await refetch(); + } catch { + // intentional + } + } + } + const pendingTier = isPending && variables?.data?.tier ? variables.data.tier : null; @@ -133,5 +174,6 @@ export function useSubscriptionTierSection() { isPaymentEnabled, changeTier, handleTierChange, + cancelPendingChange, }; } diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index 920348db25..f20f34a805 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -2470,7 +2470,7 @@ }, "post": { "tags": ["v1", "credits"], - "summary": "Start a Stripe Checkout session to upgrade subscription tier", + "summary": "Update subscription tier or start a Stripe Checkout session", "operationId": "updateSubscriptionTier", "requestBody": { "content": { @@ -2488,7 +2488,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/SubscriptionCheckoutResponse" + "$ref": "#/components/schemas/SubscriptionStatusResponse" } } } @@ -14208,12 +14208,6 @@ "enum": ["DRAFT", "PENDING", "APPROVED", "REJECTED"], "title": "SubmissionStatus" }, - "SubscriptionCheckoutResponse": { - "properties": { "url": { "type": "string", "title": "Url" } }, - "type": "object", - "required": ["url"], - "title": "SubscriptionCheckoutResponse" - }, "SubscriptionStatusResponse": { "properties": { "tier": { @@ -14230,6 +14224,26 @@ "proration_credit_cents": { "type": "integer", "title": "Proration Credit Cents" + }, + "pending_tier": { + "anyOf": [ + { "type": "string", "enum": ["FREE", "PRO", "BUSINESS"] }, + { "type": "null" } + ], + "title": "Pending Tier" + }, + "pending_tier_effective_at": { + "anyOf": [ + { "type": "string", "format": "date-time" }, + { "type": "null" } + ], + "title": "Pending Tier Effective At" + }, + "url": { + "type": "string", + "title": "Url", + "description": "Populated only when POST /credits/subscription starts a Stripe Checkout Session (FREE → paid upgrade). Empty string in all other branches — the client redirects to this URL when non-empty.", + "default": "" } }, "type": "object", From 01f1289aac2e8408adbf2aa50d5fa5b2344ec488 Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Tue, 21 Apr 2026 14:34:43 +0700 Subject: [PATCH 04/41] feat(copilot): real OpenRouter cost + cost-based rate limits (percent-only public API) (#12864) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Why After d7653acd0 removed cost estimation, most baseline turns log with `tracking_type="tokens"` and no authoritative USD figure (see: dashboard flipped from `cost_usd` to `tokens` after 4/14/2026). Rate-limit counters were also token-weighted with hand-rolled cache discounts (cache_read @ 10%, cache_create @ 25%) and a 5× Opus multiplier — a proxy for cost that drifts from real OpenRouter billing. This PR wires real generation cost from OpenRouter into both the cost-tracking log and the rate limiter, and hides raw spend figures from the user-facing API so clients can't reverse-engineer per-turn cost or platform margins. ## What 1. **Real cost from OpenRouter** — baseline passes `extra_body={"usage": {"include": True}}` and reads `chunk.usage.cost` from the final streaming chunk. `x-total-cost` header path removed. Missing cost logs an error and skips the counter update (vs the old estimator that silently under-counted). 2. **Cost-based rate limiting** — `record_token_usage(...)` → `record_cost_usage(cost_microdollars)`. The weighted-token math, cache discount factors, and `_OPUS_COST_MULTIPLIER` are gone; real USD already reflects model + cache pricing. 3. **Redis key migration** — `copilot:usage:*` → `copilot:cost:*` so stale token counters can't be misinterpreted as microdollars. 4. **LD flags + config** — renamed to `copilot-daily-cost-limit-microdollars` / `copilot-weekly-cost-limit-microdollars` (unit in the LD key so values can't accidentally be set in dollars or cents). 5. **Public `/usage` hides raw $$** — new `CoPilotUsagePublic` / `UsageWindowPublic` schemas expose only `percent_used` (0-100) + `resets_at` + `tier` + `reset_cost`. Admin endpoint keeps raw microdollars for debugging. 6. **Admin API contract** — `UserRateLimitResponse` fields renamed `daily/weekly_token_limit` → `daily/weekly_cost_limit_microdollars`, `daily/weekly_tokens_used` → `daily/weekly_cost_used_microdollars`. Admin UI displays `$X.XX`. ## How - `baseline/service.py` — pass `extra_body`, extract cost from `chunk.usage.cost`, drop the `x-total-cost` header fallback entirely. - `rate_limit.py` — rewritten around `record_cost_usage`, `check_rate_limit(daily_cost_limit, weekly_cost_limit)`, new Redis key prefix. Adds `CoPilotUsagePublic.from_status()` projector for the public API. - `token_tracking.py` — converts `cost_usd` → microdollars via `usd_to_microdollars` and calls `record_cost_usage` only when cost is present. - `sdk/service.py` — deletes `_OPUS_COST_MULTIPLIER` and simplifies `_resolve_model_and_multiplier` to `_resolve_sdk_model_for_request`. - Chat routes: `/usage` and `/usage/reset` return `CoPilotUsagePublic`. Internal server-side limit checks still use the raw microdollar `CoPilotUsageStatus`. - Admin routes: unchanged response shape (renamed fields only). - Frontend: `UsagePanelContent`, `UsageLimits`, `CopilotPage`, `BriefingTabContent`, `credits/page.tsx` consume the new public schema and render "N% used" + progress bar. Admin `RateLimitDisplay` / `UsageBar` keep `$X.XX`. Helper `formatMicrodollarsAsUsd` retained for admin use. - Tests + snapshots rewritten; new assertions explicitly check that raw `used`/`limit` keys are absent from the public payload. ## Deploy notes 1. **Before rolling this out, create the new LD flags:** `copilot-daily-cost-limit-microdollars` (default `500000`) and `copilot-weekly-cost-limit-microdollars` (default `2500000`). Old `copilot-*-token-limit` flags can stay in LD for rollback. 2. **One-time Redis cleanup (optional):** token-based counters under `copilot:usage:*` are orphaned and will TTL out within 7 days. Safe to ignore or delete manually. ## Test plan - [x] `poetry run test` — all impacted backend tests pass (182/182 in targeted scope) - [x] `pnpm test:unit` — all 1628 integration tests pass - [x] `poetry run format` / `pnpm format` / `pnpm types` clean - [x] Manual sanity against dev env — Baseline turn logged $0.1221 for 40K/139 tokens on Sonnet 4 (matches expected pricing) - [ ] `/pr-test --fix` end-to-end against local native stack --- .../features/admin/rate_limit_admin_routes.py | 32 +- .../admin/rate_limit_admin_routes_test.py | 18 +- .../backend/api/features/chat/routes.py | 58 ++- .../backend/api/features/chat/routes_test.py | 40 +- .../backend/copilot/baseline/service.py | 106 +++- .../copilot/baseline/service_unit_test.py | 476 +++++++++--------- .../backend/backend/copilot/config.py | 32 +- .../backend/backend/copilot/rate_limit.py | 270 +++++----- .../backend/copilot/rate_limit_test.py | 100 ++-- .../backend/copilot/reset_usage_test.py | 12 +- .../backend/backend/copilot/sdk/service.py | 37 +- .../backend/backend/copilot/token_tracking.py | 83 +-- .../backend/copilot/token_tracking_test.py | 100 ++-- .../backend/backend/util/feature_flag.py | 4 +- .../backend/snapshots/get_rate_limit | 8 +- .../reset_user_usage_daily_and_weekly | 8 +- .../snapshots/reset_user_usage_daily_only | 8 +- .../(platform)/admin/components/UsageBar.tsx | 10 +- .../components/__tests__/UsageBar.test.tsx | 31 ++ .../components/RateLimitDisplay.tsx | 17 +- .../__tests__/RateLimitDisplay.test.tsx | 18 +- .../__tests__/RateLimitManager.test.tsx | 16 +- .../__tests__/useRateLimitManager.test.ts | 20 +- .../app/(platform)/copilot/CopilotPage.tsx | 8 +- .../copilot/__tests__/CopilotPage.test.tsx | 22 +- .../components/UsageLimits/UsageLimits.tsx | 10 +- .../UsageLimits/UsagePanelContent.tsx | 50 +- .../__tests__/UsageLimits.test.tsx | 75 +-- .../UsagePanelContentRender.test.tsx | 68 ++- .../components/__tests__/usageHelpers.test.ts | 76 +++ .../copilot/components/usageHelpers.ts | 6 + .../AgentBriefingPanel/BriefingTabContent.tsx | 58 +-- .../__tests__/BriefingTabContent.test.tsx | 212 ++++++++ .../profile/(user)/credits/page.tsx | 10 +- .../frontend/src/app/api/openapi.json | 80 +-- 35 files changed, 1330 insertions(+), 849 deletions(-) create mode 100644 autogpt_platform/frontend/src/app/(platform)/admin/components/__tests__/UsageBar.test.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/copilot/components/__tests__/usageHelpers.test.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/__tests__/BriefingTabContent.test.tsx diff --git a/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes.py b/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes.py index 379b9e9257..3b9c762f21 100644 --- a/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes.py +++ b/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes.py @@ -32,10 +32,10 @@ router = APIRouter( class UserRateLimitResponse(BaseModel): user_id: str user_email: Optional[str] = None - daily_token_limit: int - weekly_token_limit: int - daily_tokens_used: int - weekly_tokens_used: int + daily_cost_limit_microdollars: int + weekly_cost_limit_microdollars: int + daily_cost_used_microdollars: int + weekly_cost_used_microdollars: int tier: SubscriptionTier @@ -101,17 +101,19 @@ async def get_user_rate_limit( logger.info("Admin %s checking rate limit for user %s", admin_user_id, resolved_id) daily_limit, weekly_limit, tier = await get_global_rate_limits( - resolved_id, config.daily_token_limit, config.weekly_token_limit + resolved_id, + config.daily_cost_limit_microdollars, + config.weekly_cost_limit_microdollars, ) usage = await get_usage_status(resolved_id, daily_limit, weekly_limit, tier=tier) return UserRateLimitResponse( user_id=resolved_id, user_email=resolved_email, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, - daily_tokens_used=usage.daily.used, - weekly_tokens_used=usage.weekly.used, + daily_cost_limit_microdollars=daily_limit, + weekly_cost_limit_microdollars=weekly_limit, + daily_cost_used_microdollars=usage.daily.used, + weekly_cost_used_microdollars=usage.weekly.used, tier=tier, ) @@ -141,7 +143,9 @@ async def reset_user_rate_limit( raise HTTPException(status_code=500, detail="Failed to reset usage") from e daily_limit, weekly_limit, tier = await get_global_rate_limits( - user_id, config.daily_token_limit, config.weekly_token_limit + user_id, + config.daily_cost_limit_microdollars, + config.weekly_cost_limit_microdollars, ) usage = await get_usage_status(user_id, daily_limit, weekly_limit, tier=tier) @@ -154,10 +158,10 @@ async def reset_user_rate_limit( return UserRateLimitResponse( user_id=user_id, user_email=resolved_email, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, - daily_tokens_used=usage.daily.used, - weekly_tokens_used=usage.weekly.used, + daily_cost_limit_microdollars=daily_limit, + weekly_cost_limit_microdollars=weekly_limit, + daily_cost_used_microdollars=usage.daily.used, + weekly_cost_used_microdollars=usage.weekly.used, tier=tier, ) diff --git a/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes_test.py b/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes_test.py index 77e4a656fb..c6c920829d 100644 --- a/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes_test.py +++ b/autogpt_platform/backend/backend/api/features/admin/rate_limit_admin_routes_test.py @@ -85,10 +85,10 @@ def test_get_rate_limit( data = response.json() assert data["user_id"] == target_user_id assert data["user_email"] == _TARGET_EMAIL - assert data["daily_token_limit"] == 2_500_000 - assert data["weekly_token_limit"] == 12_500_000 - assert data["daily_tokens_used"] == 500_000 - assert data["weekly_tokens_used"] == 3_000_000 + assert data["daily_cost_limit_microdollars"] == 2_500_000 + assert data["weekly_cost_limit_microdollars"] == 12_500_000 + assert data["daily_cost_used_microdollars"] == 500_000 + assert data["weekly_cost_used_microdollars"] == 3_000_000 assert data["tier"] == "FREE" configured_snapshot.assert_match( @@ -117,7 +117,7 @@ def test_get_rate_limit_by_email( data = response.json() assert data["user_id"] == target_user_id assert data["user_email"] == _TARGET_EMAIL - assert data["daily_token_limit"] == 2_500_000 + assert data["daily_cost_limit_microdollars"] == 2_500_000 def test_get_rate_limit_by_email_not_found( @@ -160,9 +160,9 @@ def test_reset_user_usage_daily_only( assert response.status_code == 200 data = response.json() - assert data["daily_tokens_used"] == 0 + assert data["daily_cost_used_microdollars"] == 0 # Weekly is untouched - assert data["weekly_tokens_used"] == 3_000_000 + assert data["weekly_cost_used_microdollars"] == 3_000_000 assert data["tier"] == "FREE" mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=False) @@ -192,8 +192,8 @@ def test_reset_user_usage_daily_and_weekly( assert response.status_code == 200 data = response.json() - assert data["daily_tokens_used"] == 0 - assert data["weekly_tokens_used"] == 0 + assert data["daily_cost_used_microdollars"] == 0 + assert data["weekly_cost_used_microdollars"] == 0 assert data["tier"] == "FREE" mock_reset.assert_awaited_once_with(target_user_id, reset_weekly=True) diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index eceedb828c..6ef15f0999 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -34,7 +34,7 @@ from backend.copilot.pending_message_helpers import ( ) from backend.copilot.pending_messages import peek_pending_messages from backend.copilot.rate_limit import ( - CoPilotUsageStatus, + CoPilotUsagePublic, RateLimitExceeded, acquire_reset_lock, check_rate_limit, @@ -536,23 +536,27 @@ async def get_session( ) async def get_copilot_usage( user_id: Annotated[str, Security(auth.get_user_id)], -) -> CoPilotUsageStatus: +) -> CoPilotUsagePublic: """Get CoPilot usage status for the authenticated user. - Returns current token usage vs limits for daily and weekly windows. - Global defaults sourced from LaunchDarkly (falling back to config). - Includes the user's rate-limit tier. + Returns the percentage of the daily/weekly allowance used — not the + raw spend or cap — so clients cannot derive per-turn cost or platform + margins. Global defaults sourced from LaunchDarkly (falling back to + config). Includes the user's rate-limit tier. """ daily_limit, weekly_limit, tier = await get_global_rate_limits( - user_id, config.daily_token_limit, config.weekly_token_limit + user_id, + config.daily_cost_limit_microdollars, + config.weekly_cost_limit_microdollars, ) - return await get_usage_status( + status = await get_usage_status( user_id=user_id, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, + daily_cost_limit=daily_limit, + weekly_cost_limit=weekly_limit, rate_limit_reset_cost=config.rate_limit_reset_cost, tier=tier, ) + return CoPilotUsagePublic.from_status(status) class RateLimitResetResponse(BaseModel): @@ -561,7 +565,9 @@ class RateLimitResetResponse(BaseModel): success: bool credits_charged: int = Field(description="Credits charged (in cents)") remaining_balance: int = Field(description="Credit balance after charge (in cents)") - usage: CoPilotUsageStatus = Field(description="Updated usage status after reset") + usage: CoPilotUsagePublic = Field( + description="Updated usage status after reset (percentages only)" + ) @router.post( @@ -585,7 +591,7 @@ async def reset_copilot_usage( ) -> RateLimitResetResponse: """Reset the daily CoPilot rate limit by spending credits. - Allows users who have hit their daily token limit to spend credits + Allows users who have hit their daily cost limit to spend credits to reset their daily usage counter and continue working. Returns 400 if the feature is disabled or the user is not over the limit. Returns 402 if the user has insufficient credits. @@ -604,7 +610,9 @@ async def reset_copilot_usage( ) daily_limit, weekly_limit, tier = await get_global_rate_limits( - user_id, config.daily_token_limit, config.weekly_token_limit + user_id, + config.daily_cost_limit_microdollars, + config.weekly_cost_limit_microdollars, ) if daily_limit <= 0: @@ -641,8 +649,8 @@ async def reset_copilot_usage( # used for limit checks, not returned to the client.) usage_status = await get_usage_status( user_id=user_id, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, + daily_cost_limit=daily_limit, + weekly_cost_limit=weekly_limit, tier=tier, ) if daily_limit > 0 and usage_status.daily.used < daily_limit: @@ -677,7 +685,7 @@ async def reset_copilot_usage( # Reset daily usage in Redis. If this fails, refund the credits # so the user is not charged for a service they did not receive. - if not await reset_daily_usage(user_id, daily_token_limit=daily_limit): + if not await reset_daily_usage(user_id, daily_cost_limit=daily_limit): # Compensate: refund the charged credits. refunded = False try: @@ -713,11 +721,11 @@ async def reset_copilot_usage( finally: await release_reset_lock(user_id) - # Return updated usage status. + # Return updated usage status (public schema — percentages only). updated_usage = await get_usage_status( user_id=user_id, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, + daily_cost_limit=daily_limit, + weekly_cost_limit=weekly_limit, rate_limit_reset_cost=config.rate_limit_reset_cost, tier=tier, ) @@ -726,7 +734,7 @@ async def reset_copilot_usage( success=True, credits_charged=cost, remaining_balance=remaining, - usage=updated_usage, + usage=CoPilotUsagePublic.from_status(updated_usage), ) @@ -787,7 +795,7 @@ async def cancel_session_task( ), }, 404: {"description": "Session not found or access denied"}, - 429: {"description": "Token rate-limit or call-frequency cap exceeded"}, + 429: {"description": "Cost rate-limit or call-frequency cap exceeded"}, }, ) async def stream_chat_post( @@ -861,18 +869,20 @@ async def stream_chat_post( }, ) - # Pre-turn rate limit check (token-based). + # Pre-turn rate limit check (cost-based, microdollars). # check_rate_limit short-circuits internally when both limits are 0. # Global defaults sourced from LaunchDarkly, falling back to config. if user_id: try: daily_limit, weekly_limit, _ = await get_global_rate_limits( - user_id, config.daily_token_limit, config.weekly_token_limit + user_id, + config.daily_cost_limit_microdollars, + config.weekly_cost_limit_microdollars, ) await check_rate_limit( user_id=user_id, - daily_token_limit=daily_limit, - weekly_token_limit=weekly_limit, + daily_cost_limit=daily_limit, + weekly_cost_limit=weekly_limit, ) except RateLimitExceeded as e: raise HTTPException(status_code=429, detail=str(e)) from e diff --git a/autogpt_platform/backend/backend/api/features/chat/routes_test.py b/autogpt_platform/backend/backend/api/features/chat/routes_test.py index 4dc6547515..88c4ef5f14 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes_test.py @@ -296,8 +296,8 @@ def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockerF _mock_stream_internals(mocker) # Ensure the rate-limit branch is entered by setting a non-zero limit. - mocker.patch.object(chat_routes.config, "daily_token_limit", 10000) - mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000) + mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000) + mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000) mocker.patch( "backend.api.features.chat.routes.check_rate_limit", side_effect=RateLimitExceeded("daily", datetime.now(UTC) + timedelta(hours=1)), @@ -318,8 +318,8 @@ def test_stream_chat_returns_429_on_weekly_rate_limit( from backend.copilot.rate_limit import RateLimitExceeded _mock_stream_internals(mocker) - mocker.patch.object(chat_routes.config, "daily_token_limit", 10000) - mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000) + mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000) + mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000) resets_at = datetime.now(UTC) + timedelta(days=3) mocker.patch( "backend.api.features.chat.routes.check_rate_limit", @@ -341,8 +341,8 @@ def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockerFixture): from backend.copilot.rate_limit import RateLimitExceeded _mock_stream_internals(mocker) - mocker.patch.object(chat_routes.config, "daily_token_limit", 10000) - mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000) + mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000) + mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000) mocker.patch( "backend.api.features.chat.routes.check_rate_limit", side_effect=RateLimitExceeded( @@ -402,23 +402,33 @@ def test_usage_returns_daily_and_weekly( mocker: pytest_mock.MockerFixture, test_user_id: str, ) -> None: - """GET /usage returns daily and weekly usage.""" + """GET /usage returns percentages for daily and weekly windows only. + + The raw used/limit microdollar values MUST NOT leak — clients should not + be able to derive per-turn cost or platform margins from the public API. + """ mock_get = _mock_usage(mocker, daily_used=500, weekly_used=2000) - mocker.patch.object(chat_routes.config, "daily_token_limit", 10000) - mocker.patch.object(chat_routes.config, "weekly_token_limit", 50000) + mocker.patch.object(chat_routes.config, "daily_cost_limit_microdollars", 10000) + mocker.patch.object(chat_routes.config, "weekly_cost_limit_microdollars", 50000) response = client.get("/usage") assert response.status_code == 200 data = response.json() - assert data["daily"]["used"] == 500 - assert data["weekly"]["used"] == 2000 + # 500 / 10000 = 5%, 2000 / 50000 = 4% + assert data["daily"]["percent_used"] == 5.0 + assert data["weekly"]["percent_used"] == 4.0 + # Raw spend/limit must not be exposed. + assert "used" not in data["daily"] + assert "limit" not in data["daily"] + assert "used" not in data["weekly"] + assert "limit" not in data["weekly"] mock_get.assert_called_once_with( user_id=test_user_id, - daily_token_limit=10000, - weekly_token_limit=50000, + daily_cost_limit=10000, + weekly_cost_limit=50000, rate_limit_reset_cost=chat_routes.config.rate_limit_reset_cost, tier=SubscriptionTier.FREE, ) @@ -438,8 +448,8 @@ def test_usage_uses_config_limits( assert response.status_code == 200 mock_get.assert_called_once_with( user_id=test_user_id, - daily_token_limit=99999, - weekly_token_limit=77777, + daily_cost_limit=99999, + weekly_cost_limit=77777, rate_limit_reset_cost=500, tier=SubscriptionTier.FREE, ) diff --git a/autogpt_platform/backend/backend/copilot/baseline/service.py b/autogpt_platform/backend/backend/copilot/baseline/service.py index 7d27beac8b..8a26002e25 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service.py @@ -22,7 +22,9 @@ from typing import TYPE_CHECKING, Any, cast import orjson from langfuse import propagate_attributes +from openai.types import CompletionUsage from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam +from openai.types.completion_usage import PromptTokensDetails from opentelemetry import trace as otel_trace from backend.copilot.config import CopilotLlmModel, CopilotMode @@ -126,6 +128,53 @@ _MAX_INLINE_IMAGE_BYTES = 20 * 1024 * 1024 # Matches characters unsafe for filenames. _UNSAFE_FILENAME = re.compile(r"[^\w.\-]") +# OpenRouter-specific extra_body flag that embeds the real generation cost +# into the final usage chunk. Module-level constant so we don't reallocate +# an identical dict on every streaming call. +_OPENROUTER_INCLUDE_USAGE_COST = {"usage": {"include": True}} + + +def _extract_usage_cost(usage: CompletionUsage) -> float | None: + """Return the provider-reported USD cost on a streaming usage chunk. + + OpenRouter piggybacks a ``cost`` field on the OpenAI-compatible usage + object when the request body includes ``usage: {"include": True}``. + The OpenAI SDK's typed ``CompletionUsage`` does not declare it, so we + read it off ``model_extra`` (the pydantic v2 container for extras) to + keep the access fully typed — no ``getattr``. + + Returns ``None`` when the field is absent, explicitly null, + non-numeric, non-finite, or negative. Invalid values (including + present-but-null) are logged here — they indicate a provider bug + worth chasing; plain absences are silent so the caller can dedupe + the "missing cost" warning per stream. + """ + extras = usage.model_extra or {} + if "cost" not in extras: + return None + raw = extras["cost"] + if raw is None: + logger.error("[Baseline] usage.cost is present but null") + return None + try: + val = float(raw) + except (TypeError, ValueError): + logger.error("[Baseline] usage.cost is not numeric: %r", raw) + return None + if not math.isfinite(val) or val < 0: + logger.error("[Baseline] usage.cost is non-finite or negative: %r", val) + return None + return val + + +def _extract_cache_creation_tokens(ptd: PromptTokensDetails) -> int: + """Read Anthropic's ``cache_creation_input_tokens`` off an OpenAI + ``PromptTokensDetails`` — it's a provider-specific extra, not in the + typed model, so we read it via ``model_extra`` rather than + ``getattr``. + """ + return int((ptd.model_extra or {}).get("cache_creation_input_tokens") or 0) + async def _prepare_baseline_attachments( file_ids: list[str], @@ -267,6 +316,10 @@ class _BaselineStreamState: turn_cache_read_tokens: int = 0 turn_cache_creation_tokens: int = 0 cost_usd: float | None = None + # Tracks whether we've already warned about a missing `cost` field in + # the usage chunk this stream, so non-OpenRouter providers don't + # generate one warning per streaming call. + cost_missing_logged: bool = False thinking_stripper: _ThinkingStripper = field(default_factory=_ThinkingStripper) session_messages: list[ChatMessage] = field(default_factory=list) # Tracks how much of ``assistant_text`` has already been flushed to @@ -292,10 +345,12 @@ async def _baseline_llm_caller( state.thinking_stripper = _ThinkingStripper() round_text = "" - response = None # initialized before try so finally block can access it try: client = _get_openai_client() typed_messages = cast(list[ChatCompletionMessageParam], messages) + # extra_body `usage.include=true` asks OpenRouter to embed the real + # generation cost into the final usage chunk. Without this we only get + # token counts and have no authoritative cost for rate limiting. if tools: typed_tools = cast(list[ChatCompletionToolParam], tools) response = await client.chat.completions.create( @@ -304,6 +359,7 @@ async def _baseline_llm_caller( tools=typed_tools, stream=True, stream_options={"include_usage": True}, + extra_body=_OPENROUTER_INCLUDE_USAGE_COST, ) else: response = await client.chat.completions.create( @@ -311,6 +367,7 @@ async def _baseline_llm_caller( messages=typed_messages, stream=True, stream_options={"include_usage": True}, + extra_body=_OPENROUTER_INCLUDE_USAGE_COST, ) tool_calls_by_index: dict[int, dict[str, str]] = {} @@ -323,18 +380,33 @@ async def _baseline_llm_caller( if chunk.usage: state.turn_prompt_tokens += chunk.usage.prompt_tokens or 0 state.turn_completion_tokens += chunk.usage.completion_tokens or 0 - # Extract cache token details when available (OpenAI / - # OpenRouter include these in prompt_tokens_details). - ptd = getattr(chunk.usage, "prompt_tokens_details", None) + ptd = chunk.usage.prompt_tokens_details if ptd: - state.turn_cache_read_tokens += ( - getattr(ptd, "cached_tokens", 0) or 0 - ) - # cache_creation_input_tokens is reported by some providers - # (e.g. Anthropic native) but not standard OpenAI streaming. + state.turn_cache_read_tokens += ptd.cached_tokens or 0 state.turn_cache_creation_tokens += ( - getattr(ptd, "cache_creation_input_tokens", 0) or 0 + _extract_cache_creation_tokens(ptd) ) + cost = _extract_usage_cost(chunk.usage) + if cost is not None: + state.cost_usd = (state.cost_usd or 0.0) + cost + elif ( + "cost" not in (chunk.usage.model_extra or {}) + and not state.cost_missing_logged + ): + # Field absent (non-OpenRouter route, or OpenRouter + # misconfigured) — warn once per stream so error + # monitoring picks up persistent misses without + # flooding. Invalid values already logged inside + # _extract_usage_cost, so no duplicate warning here. + logger.warning( + "[Baseline] usage chunk missing cost (model=%s, " + "prompt=%s, completion=%s) — rate-limit will " + "skip this call", + state.model, + chunk.usage.prompt_tokens, + chunk.usage.completion_tokens, + ) + state.cost_missing_logged = True delta = chunk.choices[0].delta if chunk.choices else None if not delta: @@ -394,20 +466,6 @@ async def _baseline_llm_caller( state.text_started = False state.text_block_id = str(uuid.uuid4()) finally: - # Extract OpenRouter cost from response headers (in finally so we - # capture cost even when the stream errors mid-way — we already paid). - # Accumulate across multi-round tool-calling turns. - try: - # Access undocumented _response attribute — same pattern as - # extract_openrouter_cost() in blocks/llm.py. - cost_header = response._response.headers.get("x-total-cost") # type: ignore[attr-defined] - if cost_header: - cost = float(cost_header) - if math.isfinite(cost) and cost >= 0: - state.cost_usd = (state.cost_usd or 0.0) + cost - except (AttributeError, ValueError): - pass - # Always persist partial text so the session history stays consistent, # even when the stream is interrupted by an exception. state.assistant_text += round_text diff --git a/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py b/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py index a0e55d843f..e21618c367 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py @@ -11,6 +11,7 @@ from openai.types.chat import ChatCompletionToolParam from backend.copilot.baseline.service import ( _baseline_conversation_updater, + _baseline_llm_caller, _BaselineStreamState, _compress_session_messages, ) @@ -574,37 +575,80 @@ class TestPrepareBaselineAttachments: assert blocks == [] +_COST_MISSING = object() + + +def _make_usage_chunk( + *, + prompt_tokens: int = 0, + completion_tokens: int = 0, + cost: float | str | None | object = _COST_MISSING, + cached_tokens: int | None = None, + cache_creation_input_tokens: int | None = None, +): + """Build a mock streaming chunk carrying usage (and optionally cost). + + Provider-specific fields (``cost`` on usage, ``cache_creation_input_tokens`` + on prompt_tokens_details) are set on ``model_extra`` because that's where + the baseline helper reads them from (typed ``CompletionUsage.model_extra`` + rather than ``getattr``). Pass ``cost=None`` to emit an explicit-null cost + key; omit ``cost`` entirely to leave the key absent. + """ + chunk = MagicMock() + chunk.choices = [] + chunk.usage = MagicMock() + chunk.usage.prompt_tokens = prompt_tokens + chunk.usage.completion_tokens = completion_tokens + usage_extras: dict[str, float | str | None] = {} + if cost is not _COST_MISSING: + usage_extras["cost"] = cost # type: ignore[assignment] + chunk.usage.model_extra = usage_extras + + if cached_tokens is not None or cache_creation_input_tokens is not None: + ptd = MagicMock() + ptd.cached_tokens = cached_tokens or 0 + ptd.model_extra = { + "cache_creation_input_tokens": cache_creation_input_tokens or 0 + } + chunk.usage.prompt_tokens_details = ptd + else: + chunk.usage.prompt_tokens_details = None + + return chunk + + +def _make_stream_mock(*chunks): + """Build an async streaming response mock that yields *chunks* in order.""" + stream = MagicMock() + stream.close = AsyncMock() + + async def aiter(): + for c in chunks: + yield c + + stream.__aiter__ = lambda self: aiter() + return stream + + class TestBaselineCostExtraction: - """Tests for x-total-cost header extraction in _baseline_llm_caller.""" + """Tests for ``usage.cost`` extraction in ``_baseline_llm_caller``. + + Cost is read from the OpenRouter ``usage.cost`` field on the final + streaming chunk when the request body includes ``usage: {include: true}`` + (handled by the baseline service via ``extra_body``). + """ @pytest.mark.asyncio - async def test_cost_usd_extracted_from_response_header(self): - """state.cost_usd is set from x-total-cost header when present.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, - ) - + async def test_cost_usd_extracted_from_usage_chunk(self): + """state.cost_usd is set from chunk.usage.cost when present.""" state = _BaselineStreamState(model="gpt-4o-mini") - - # Build a mock raw httpx response with the cost header - mock_raw_response = MagicMock() - mock_raw_response.headers = {"x-total-cost": "0.0123"} - - # Build a mock async streaming response that yields no chunks but has - # a _response attribute pointing to the mock httpx response - mock_stream_response = MagicMock() - mock_stream_response._response = mock_raw_response - - async def empty_aiter(): - return - yield # make it an async generator - - mock_stream_response.__aiter__ = lambda self: empty_aiter() + chunk = _make_usage_chunk( + prompt_tokens=1000, completion_tokens=200, cost=0.0123 + ) mock_client = MagicMock() mock_client.chat.completions.create = AsyncMock( - return_value=mock_stream_response + return_value=_make_stream_mock(chunk) ) with patch( @@ -622,29 +666,14 @@ class TestBaselineCostExtraction: @pytest.mark.asyncio async def test_cost_usd_accumulates_across_calls(self): """cost_usd accumulates when _baseline_llm_caller is called multiple times.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, - ) - state = _BaselineStreamState(model="gpt-4o-mini") - def make_stream_mock(cost: str) -> MagicMock: - mock_raw = MagicMock() - mock_raw.headers = {"x-total-cost": cost} - mock_stream = MagicMock() - mock_stream._response = mock_raw - - async def empty_aiter(): - return - yield - - mock_stream.__aiter__ = lambda self: empty_aiter() - return mock_stream - mock_client = MagicMock() mock_client.chat.completions.create = AsyncMock( - side_effect=[make_stream_mock("0.01"), make_stream_mock("0.02")] + side_effect=[ + _make_stream_mock(_make_usage_chunk(prompt_tokens=500, cost=0.01)), + _make_stream_mock(_make_usage_chunk(prompt_tokens=600, cost=0.02)), + ] ) with patch( @@ -665,28 +694,64 @@ class TestBaselineCostExtraction: assert state.cost_usd == pytest.approx(0.03) @pytest.mark.asyncio - async def test_no_cost_when_header_absent(self): - """state.cost_usd remains None when response has no x-total-cost header.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, - ) - + async def test_cost_usd_accepts_string_value(self): + """OpenRouter may emit cost as a string — it should still parse.""" state = _BaselineStreamState(model="gpt-4o-mini") - - mock_raw = MagicMock() - mock_raw.headers = {} - mock_stream = MagicMock() - mock_stream._response = mock_raw - - async def empty_aiter(): - return - yield - - mock_stream.__aiter__ = lambda self: empty_aiter() + chunk = _make_usage_chunk(prompt_tokens=10, cost="0.005") mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=mock_stream) + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + assert state.cost_usd == pytest.approx(0.005) + + @pytest.mark.asyncio + async def test_cost_usd_none_when_usage_cost_missing(self): + """state.cost_usd stays None when the usage chunk lacks a cost field.""" + state = _BaselineStreamState(model="anthropic/claude-sonnet-4") + chunk = _make_usage_chunk(prompt_tokens=1000, completion_tokens=500) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + assert state.cost_usd is None + # Token accumulators are still populated so the caller can log them. + assert state.turn_prompt_tokens == 1000 + assert state.turn_completion_tokens == 500 + + @pytest.mark.asyncio + async def test_invalid_cost_string_leaves_cost_none(self): + """A non-numeric cost value is rejected without raising.""" + state = _BaselineStreamState(model="gpt-4o-mini") + chunk = _make_usage_chunk(prompt_tokens=10, cost="not-a-number") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) with patch( "backend.copilot.baseline.service._get_openai_client", @@ -701,28 +766,73 @@ class TestBaselineCostExtraction: assert state.cost_usd is None @pytest.mark.asyncio - async def test_cost_extracted_even_when_stream_raises(self): - """cost_usd is captured in the finally block even when streaming fails.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, + async def test_negative_cost_is_ignored(self): + """Guard against negative cost values (shouldn't happen but be safe).""" + state = _BaselineStreamState(model="gpt-4o-mini") + chunk = _make_usage_chunk(prompt_tokens=10, cost=-0.01) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) ) + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + assert state.cost_usd is None + + @pytest.mark.asyncio + async def test_explicit_null_cost_is_logged_and_ignored(self, caplog): + """`{"cost": null}` is rejected and logged (not silently dropped).""" + state = _BaselineStreamState(model="openrouter/auto") + chunk = _make_usage_chunk(prompt_tokens=10, cost=None) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) + + with ( + patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ), + caplog.at_level("ERROR", logger="backend.copilot.baseline.service"), + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + assert state.cost_usd is None + assert any( + "usage.cost is present but null" in rec.message for rec in caplog.records + ) + + @pytest.mark.asyncio + async def test_cost_not_captured_when_stream_raises_mid_chunk(self): + """If the stream aborts before emitting the usage chunk there is no cost.""" state = _BaselineStreamState(model="gpt-4o-mini") - mock_raw = MagicMock() - mock_raw.headers = {"x-total-cost": "0.005"} - mock_stream = MagicMock() - mock_stream._response = mock_raw + stream = MagicMock() + stream.close = AsyncMock() async def failing_aiter(): raise RuntimeError("stream error") yield # make it an async generator - mock_stream.__aiter__ = lambda self: failing_aiter() + stream.__aiter__ = lambda self: failing_aiter() mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=mock_stream) + mock_client.chat.completions.create = AsyncMock(return_value=stream) with ( patch( @@ -737,16 +847,12 @@ class TestBaselineCostExtraction: state=state, ) - assert state.cost_usd == pytest.approx(0.005) + # Stream aborted before yielding the usage chunk — cost stays None. + assert state.cost_usd is None @pytest.mark.asyncio async def test_no_cost_when_api_call_raises_before_stream(self): - """finally block is safe when response is None (API call failed before yielding).""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, - ) - + """The helper is safe when the create() call itself raises.""" state = _BaselineStreamState(model="gpt-4o-mini") mock_client = MagicMock() @@ -767,84 +873,23 @@ class TestBaselineCostExtraction: state=state, ) - # response was never assigned so cost extraction must not raise - assert state.cost_usd is None - - @pytest.mark.asyncio - async def test_no_cost_when_header_missing(self): - """cost_usd remains None when x-total-cost is absent.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, - ) - - state = _BaselineStreamState(model="anthropic/claude-sonnet-4") - - mock_raw = MagicMock() - mock_raw.headers = {} # no x-total-cost - mock_stream = MagicMock() - mock_stream._response = mock_raw - - mock_chunk = MagicMock() - mock_chunk.usage = MagicMock() - mock_chunk.usage.prompt_tokens = 1000 - mock_chunk.usage.completion_tokens = 500 - mock_chunk.usage.prompt_tokens_details = None - mock_chunk.choices = [] - - async def chunk_aiter(): - yield mock_chunk - - mock_stream.__aiter__ = lambda self: chunk_aiter() - - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=mock_stream) - - with patch( - "backend.copilot.baseline.service._get_openai_client", - return_value=mock_client, - ): - await _baseline_llm_caller( - messages=[{"role": "user", "content": "hi"}], - tools=[], - state=state, - ) - assert state.cost_usd is None @pytest.mark.asyncio async def test_cache_tokens_extracted_from_usage_details(self): """cache tokens are extracted from prompt_tokens_details.cached_tokens.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, + state = _BaselineStreamState(model="openai/gpt-4o") + chunk = _make_usage_chunk( + prompt_tokens=1000, + completion_tokens=200, + cost=0.01, + cached_tokens=800, ) - state = _BaselineStreamState(model="openai/gpt-4o") - - mock_raw = MagicMock() - mock_raw.headers = {"x-total-cost": "0.01"} - mock_stream = MagicMock() - mock_stream._response = mock_raw - - # Create a chunk with prompt_tokens_details - mock_ptd = MagicMock() - mock_ptd.cached_tokens = 800 - - mock_chunk = MagicMock() - mock_chunk.usage = MagicMock() - mock_chunk.usage.prompt_tokens = 1000 - mock_chunk.usage.completion_tokens = 200 - mock_chunk.usage.prompt_tokens_details = mock_ptd - mock_chunk.choices = [] - - async def chunk_aiter(): - yield mock_chunk - - mock_stream.__aiter__ = lambda self: chunk_aiter() - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=mock_stream) + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) with patch( "backend.copilot.baseline.service._get_openai_client", @@ -861,37 +906,20 @@ class TestBaselineCostExtraction: @pytest.mark.asyncio async def test_cache_creation_tokens_extracted_from_usage_details(self): - """cache_creation_tokens are extracted from prompt_tokens_details.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, + """cache_creation_input_tokens is extracted from prompt_tokens_details.""" + state = _BaselineStreamState(model="openai/gpt-4o") + chunk = _make_usage_chunk( + prompt_tokens=1000, + completion_tokens=200, + cost=0.01, + cached_tokens=0, + cache_creation_input_tokens=500, ) - state = _BaselineStreamState(model="openai/gpt-4o") - - mock_raw = MagicMock() - mock_raw.headers = {"x-total-cost": "0.01"} - mock_stream = MagicMock() - mock_stream._response = mock_raw - - mock_ptd = MagicMock() - mock_ptd.cached_tokens = 0 - mock_ptd.cache_creation_input_tokens = 500 - - mock_chunk = MagicMock() - mock_chunk.usage = MagicMock() - mock_chunk.usage.prompt_tokens = 1000 - mock_chunk.usage.completion_tokens = 200 - mock_chunk.usage.prompt_tokens_details = mock_ptd - mock_chunk.choices = [] - - async def chunk_aiter(): - yield mock_chunk - - mock_stream.__aiter__ = lambda self: chunk_aiter() - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=mock_stream) + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) with patch( "backend.copilot.baseline.service._get_openai_client", @@ -908,37 +936,17 @@ class TestBaselineCostExtraction: @pytest.mark.asyncio async def test_token_accumulators_track_across_multiple_calls(self): """Token accumulators grow correctly across multiple _baseline_llm_caller calls.""" - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, - ) - state = _BaselineStreamState(model="anthropic/claude-sonnet-4") - def make_stream(prompt_tokens: int, completion_tokens: int): - mock_raw = MagicMock() - mock_raw.headers = {} # no x-total-cost - mock_stream = MagicMock() - mock_stream._response = mock_raw - - mock_chunk = MagicMock() - mock_chunk.usage = MagicMock() - mock_chunk.usage.prompt_tokens = prompt_tokens - mock_chunk.usage.completion_tokens = completion_tokens - mock_chunk.usage.prompt_tokens_details = None - mock_chunk.choices = [] - - async def chunk_aiter(): - yield mock_chunk - - mock_stream.__aiter__ = lambda self: chunk_aiter() - return mock_stream - mock_client = MagicMock() mock_client.chat.completions.create = AsyncMock( side_effect=[ - make_stream(1000, 200), - make_stream(1100, 300), + _make_stream_mock( + _make_usage_chunk(prompt_tokens=1000, completion_tokens=200) + ), + _make_stream_mock( + _make_usage_chunk(prompt_tokens=1100, completion_tokens=300) + ), ] ) @@ -957,45 +965,33 @@ class TestBaselineCostExtraction: state=state, ) - # No x-total-cost header and empty pricing table -- cost_usd remains None + # No usage.cost on either chunk → cost stays None, tokens still accumulate. assert state.cost_usd is None - # Accumulators hold all tokens across both turns assert state.turn_prompt_tokens == 2100 assert state.turn_completion_tokens == 500 + @pytest.mark.parametrize( + "tools", + [ + pytest.param([], id="no_tools"), + pytest.param([_make_tool("search")], id="with_tools"), + ], + ) @pytest.mark.asyncio - async def test_cost_usd_remains_none_when_header_missing(self): - """cost_usd stays None when x-total-cost header is absent. + async def test_baseline_requests_usage_include_extra_body( + self, tools: list[ChatCompletionToolParam] + ): + """The baseline call must pass extra_body={'usage': {'include': True}}. - Token counts are still tracked; persist_and_record_usage handles - the None cost by falling back to tracking_type='tokens'. + This guards the contract with OpenRouter that triggers inclusion of + the authoritative cost on the final usage chunk. Without it the + rate-limit counter stays at zero. Exercise both the no-tools and + tool-calling branches so a regression in either path trips the test. """ - from backend.copilot.baseline.service import ( - _baseline_llm_caller, - _BaselineStreamState, - ) - - state = _BaselineStreamState(model="anthropic/claude-sonnet-4") - - mock_raw = MagicMock() - mock_raw.headers = {} # no x-total-cost - mock_stream = MagicMock() - mock_stream._response = mock_raw - - mock_chunk = MagicMock() - mock_chunk.usage = MagicMock() - mock_chunk.usage.prompt_tokens = 1000 - mock_chunk.usage.completion_tokens = 500 - mock_chunk.usage.prompt_tokens_details = None - mock_chunk.choices = [] - - async def chunk_aiter(): - yield mock_chunk - - mock_stream.__aiter__ = lambda self: chunk_aiter() - + state = _BaselineStreamState(model="gpt-4o-mini") + create_mock = AsyncMock(return_value=_make_stream_mock()) mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=mock_stream) + mock_client.chat.completions.create = create_mock with patch( "backend.copilot.baseline.service._get_openai_client", @@ -1003,13 +999,15 @@ class TestBaselineCostExtraction: ): await _baseline_llm_caller( messages=[{"role": "user", "content": "hi"}], - tools=[], + tools=tools, state=state, ) - assert state.cost_usd is None - assert state.turn_prompt_tokens == 1000 - assert state.turn_completion_tokens == 500 + create_mock.assert_awaited_once() + await_args = create_mock.await_args + assert await_args is not None + assert await_args.kwargs["extra_body"] == {"usage": {"include": True}} + assert await_args.kwargs["stream_options"] == {"include_usage": True} class TestMidLoopPendingFlushOrdering: diff --git a/autogpt_platform/backend/backend/copilot/config.py b/autogpt_platform/backend/backend/copilot/config.py index ee4c717dbe..3277854172 100644 --- a/autogpt_platform/backend/backend/copilot/config.py +++ b/autogpt_platform/backend/backend/copilot/config.py @@ -101,25 +101,31 @@ class ChatConfig(BaseSettings): description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)", ) - # Rate limiting — token-based limits per day and per week. - # Per-turn token cost varies with context size: ~10-15K for early turns, - # ~30-50K mid-session, up to ~100K pre-compaction. Average across a - # session with compaction cycles is ~25-35K tokens/turn, so 2.5M daily - # allows ~70-100 turns/day. + # Rate limiting — cost-based limits per day and per week, stored in + # microdollars (1 USD = 1_000_000). The counter tracks the real + # generation cost reported by the provider (OpenRouter ``usage.cost`` + # or Claude Agent SDK ``total_cost_usd``), so cache discounts and + # cross-model price differences are already reflected — no token + # weighting or model multiplier is applied on top. # Checked at the HTTP layer (routes.py) before each turn. # - # These are base limits for the FREE tier. Higher tiers (PRO, BUSINESS, + # These are base limits for the FREE tier. Higher tiers (PRO, BUSINESS, # ENTERPRISE) multiply these by their tier multiplier (see - # rate_limit.TIER_MULTIPLIERS). User tier is stored in the + # rate_limit.TIER_MULTIPLIERS). User tier is stored in the # User.subscriptionTier DB column and resolved inside # get_global_rate_limits(). - daily_token_limit: int = Field( - default=2_500_000, - description="Max tokens per day, resets at midnight UTC (0 = unlimited)", + # + # These defaults act as the ceiling when LaunchDarkly is unreachable; + # the live per-tier values come from the COPILOT_*_COST_LIMIT flags. + daily_cost_limit_microdollars: int = Field( + default=1_000_000, + description="Max cost per day in microdollars, resets at midnight UTC " + "(0 = unlimited).", ) - weekly_token_limit: int = Field( - default=12_500_000, - description="Max tokens per week, resets Monday 00:00 UTC (0 = unlimited)", + weekly_cost_limit_microdollars: int = Field( + default=5_000_000, + description="Max cost per week in microdollars, resets Monday 00:00 UTC " + "(0 = unlimited).", ) # Cost (in credits / cents) to reset the daily rate limit using credits. diff --git a/autogpt_platform/backend/backend/copilot/rate_limit.py b/autogpt_platform/backend/backend/copilot/rate_limit.py index c08cb1b3a8..472ddf79b0 100644 --- a/autogpt_platform/backend/backend/copilot/rate_limit.py +++ b/autogpt_platform/backend/backend/copilot/rate_limit.py @@ -1,9 +1,16 @@ -"""CoPilot rate limiting based on token usage. +"""CoPilot rate limiting based on generation cost. -Uses Redis fixed-window counters to track per-user token consumption -with configurable daily and weekly limits. Daily windows reset at -midnight UTC; weekly windows reset at ISO week boundary (Monday 00:00 -UTC). Fails open when Redis is unavailable to avoid blocking users. +Uses Redis fixed-window counters to track per-user USD spend (stored as +microdollars, matching ``PlatformCostLog.cost_microdollars``) with +configurable daily and weekly limits. Daily windows reset at midnight UTC; +weekly windows reset at ISO week boundary (Monday 00:00 UTC). Fails open +when Redis is unavailable to avoid blocking users. + +Storing microdollars rather than tokens means the counter already reflects +real model pricing (including cache discounts and provider surcharges), so +this module carries no pricing table — the cost comes from OpenRouter's +``usage.cost`` field (baseline) or the Claude Agent SDK's reported total +cost (SDK path). """ import asyncio @@ -22,8 +29,10 @@ from backend.util.cache import cached logger = logging.getLogger(__name__) -# Redis key prefixes -_USAGE_KEY_PREFIX = "copilot:usage" +# Redis key prefixes. Bumped from "copilot:usage" (token-based) to +# "copilot:cost" on the token→cost migration so stale counters do not +# get misinterpreted as microdollars (which would dramatically under-count). +_USAGE_KEY_PREFIX = "copilot:cost" # --------------------------------------------------------------------------- @@ -32,7 +41,7 @@ _USAGE_KEY_PREFIX = "copilot:usage" class SubscriptionTier(str, Enum): - """Subscription tiers with increasing token allowances. + """Subscription tiers with increasing cost allowances. Mirrors the ``SubscriptionTier`` enum in ``schema.prisma``. Once ``prisma generate`` is run, this can be replaced with:: @@ -46,9 +55,9 @@ class SubscriptionTier(str, Enum): ENTERPRISE = "ENTERPRISE" -# Multiplier applied to the base limits (from LD / config) for each tier. -# Intentionally int (not float): keeps limits as whole token counts and avoids -# floating-point rounding. If fractional multipliers are ever needed, change +# Multiplier applied to the base cost limits (from LD / config) for each tier. +# Intentionally int (not float): keeps limits as whole microdollars and avoids +# floating-point rounding. If fractional multipliers are ever needed, change # the type and round the result in get_global_rate_limits(). TIER_MULTIPLIERS: dict[SubscriptionTier, int] = { SubscriptionTier.FREE: 1, @@ -61,17 +70,27 @@ DEFAULT_TIER = SubscriptionTier.FREE class UsageWindow(BaseModel): - """Usage within a single time window.""" + """Usage within a single time window. + + ``used`` and ``limit`` are in microdollars (1 USD = 1_000_000). + """ used: int limit: int = Field( - description="Maximum tokens allowed in this window. 0 means unlimited." + description="Maximum microdollars of spend allowed in this window. " + "0 means unlimited." ) resets_at: datetime class CoPilotUsageStatus(BaseModel): - """Current usage status for a user across all windows.""" + """Current usage status for a user across all windows. + + Internal representation used by server-side code that needs to compare + usage against limits (e.g. the reset-credits endpoint). The public API + returns ``CoPilotUsagePublic`` instead so that raw spend and limit + figures never leak to clients. + """ daily: UsageWindow weekly: UsageWindow @@ -82,6 +101,68 @@ class CoPilotUsageStatus(BaseModel): ) +class UsageWindowPublic(BaseModel): + """Public view of a usage window — only the percentage and reset time. + + Hides the raw spend and the cap so clients cannot derive per-turn cost + or reverse-engineer platform margins. ``percent_used`` is capped at 100. + """ + + percent_used: float = Field( + ge=0.0, + le=100.0, + description="Percentage of the window's allowance used (0-100). " + "Clamped at 100 when over the cap.", + ) + resets_at: datetime + + +class CoPilotUsagePublic(BaseModel): + """Current usage status for a user — public (client-safe) shape.""" + + daily: UsageWindowPublic | None = Field( + default=None, + description="Null when no daily cap is configured (unlimited).", + ) + weekly: UsageWindowPublic | None = Field( + default=None, + description="Null when no weekly cap is configured (unlimited).", + ) + tier: SubscriptionTier = DEFAULT_TIER + reset_cost: int = Field( + default=0, + description="Credit cost (in cents) to reset the daily limit. 0 = feature disabled.", + ) + + @classmethod + def from_status(cls, status: CoPilotUsageStatus) -> "CoPilotUsagePublic": + """Project the internal status onto the client-safe schema.""" + + def window(w: UsageWindow) -> UsageWindowPublic | None: + if w.limit <= 0: + return None + # When at/over the cap, snap to exactly 100.0 so the UI's + # rounded display and its exhaustion check (`percent_used >= 100`) + # agree. Without this, e.g. 99.95% would render as "100% used" + # via Math.round but fail the exhaustion check, leaving the + # reset button hidden while the bar appears full. + if w.used >= w.limit: + pct = 100.0 + else: + pct = round(100.0 * w.used / w.limit, 1) + return UsageWindowPublic( + percent_used=pct, + resets_at=w.resets_at, + ) + + return cls( + daily=window(status.daily), + weekly=window(status.weekly), + tier=status.tier, + reset_cost=status.reset_cost, + ) + + class RateLimitExceeded(Exception): """Raised when a user exceeds their CoPilot usage limit.""" @@ -103,8 +184,8 @@ class RateLimitExceeded(Exception): async def get_usage_status( user_id: str, - daily_token_limit: int, - weekly_token_limit: int, + daily_cost_limit: int, + weekly_cost_limit: int, rate_limit_reset_cost: int = 0, tier: SubscriptionTier = DEFAULT_TIER, ) -> CoPilotUsageStatus: @@ -112,13 +193,13 @@ async def get_usage_status( Args: user_id: The user's ID. - daily_token_limit: Max tokens per day (0 = unlimited). - weekly_token_limit: Max tokens per week (0 = unlimited). + daily_cost_limit: Max microdollars of spend per day (0 = unlimited). + weekly_cost_limit: Max microdollars of spend per week (0 = unlimited). rate_limit_reset_cost: Credit cost (cents) to reset daily limit (0 = disabled). tier: The user's rate-limit tier (included in the response). Returns: - CoPilotUsageStatus with current usage and limits. + CoPilotUsageStatus with current usage and limits in microdollars. """ now = datetime.now(UTC) daily_used = 0 @@ -137,12 +218,12 @@ async def get_usage_status( return CoPilotUsageStatus( daily=UsageWindow( used=daily_used, - limit=daily_token_limit, + limit=daily_cost_limit, resets_at=_daily_reset_time(now=now), ), weekly=UsageWindow( used=weekly_used, - limit=weekly_token_limit, + limit=weekly_cost_limit, resets_at=_weekly_reset_time(now=now), ), tier=tier, @@ -152,22 +233,22 @@ async def get_usage_status( async def check_rate_limit( user_id: str, - daily_token_limit: int, - weekly_token_limit: int, + daily_cost_limit: int, + weekly_cost_limit: int, ) -> None: """Check if user is within rate limits. Raises RateLimitExceeded if not. This is a pre-turn soft check. The authoritative usage counter is updated - by ``record_token_usage()`` after the turn completes. Under concurrency, + by ``record_cost_usage()`` after the turn completes. Under concurrency, two parallel turns may both pass this check against the same snapshot. - This is acceptable because token-based limits are approximate by nature - (the exact token count is unknown until after generation). + This is acceptable because cost-based limits are approximate by nature + (the exact cost is unknown until after generation). Fails open: if Redis is unavailable, allows the request. """ # Short-circuit: when both limits are 0 (unlimited) skip the Redis # round-trip entirely. - if daily_token_limit <= 0 and weekly_token_limit <= 0: + if daily_cost_limit <= 0 and weekly_cost_limit <= 0: return now = datetime.now(UTC) @@ -183,26 +264,25 @@ async def check_rate_limit( logger.warning("Redis unavailable for rate limit check, allowing request") return - # Worst-case overshoot: N concurrent requests × ~15K tokens each. - if daily_token_limit > 0 and daily_used >= daily_token_limit: + if daily_cost_limit > 0 and daily_used >= daily_cost_limit: raise RateLimitExceeded("daily", _daily_reset_time(now=now)) - if weekly_token_limit > 0 and weekly_used >= weekly_token_limit: + if weekly_cost_limit > 0 and weekly_used >= weekly_cost_limit: raise RateLimitExceeded("weekly", _weekly_reset_time(now=now)) -async def reset_daily_usage(user_id: str, daily_token_limit: int = 0) -> bool: - """Reset a user's daily token usage counter in Redis. +async def reset_daily_usage(user_id: str, daily_cost_limit: int = 0) -> bool: + """Reset a user's daily cost usage counter in Redis. Called after a user pays credits to extend their daily limit. - Also reduces the weekly usage counter by ``daily_token_limit`` tokens + Also reduces the weekly usage counter by ``daily_cost_limit`` microdollars (clamped to 0) so the user effectively gets one extra day's worth of weekly capacity. Args: user_id: The user's ID. - daily_token_limit: The configured daily token limit. When positive, - the weekly counter is reduced by this amount. + daily_cost_limit: The configured daily cost limit in microdollars. + When positive, the weekly counter is reduced by this amount. Returns False if Redis is unavailable so the caller can handle compensation (fail-closed for billed operations, unlike the read-only @@ -218,12 +298,12 @@ async def reset_daily_usage(user_id: str, daily_token_limit: int = 0) -> bool: # counter is not decremented — which would let the caller refund # credits even though the daily limit was already reset. d_key = _daily_key(user_id, now=now) - w_key = _weekly_key(user_id, now=now) if daily_token_limit > 0 else None + w_key = _weekly_key(user_id, now=now) if daily_cost_limit > 0 else None pipe = redis.pipeline(transaction=True) pipe.delete(d_key) if w_key is not None: - pipe.decrby(w_key, daily_token_limit) + pipe.decrby(w_key, daily_cost_limit) results = await pipe.execute() # Clamp negative weekly counter to 0 (best-effort; not critical). @@ -296,84 +376,40 @@ async def increment_daily_reset_count(user_id: str) -> None: logger.warning("Redis unavailable for tracking reset count") -async def record_token_usage( +async def record_cost_usage( user_id: str, - prompt_tokens: int, - completion_tokens: int, - *, - cache_read_tokens: int = 0, - cache_creation_tokens: int = 0, - model_cost_multiplier: float = 1.0, + cost_microdollars: int, ) -> None: - """Record token usage for a user across all windows. + """Record a user's generation spend against daily and weekly counters. - Uses cost-weighted counting so cached tokens don't unfairly penalise - multi-turn conversations. Anthropic's pricing: - - uncached input: 100% - - cache creation: 25% - - cache read: 10% - - output: 100% - - ``prompt_tokens`` should be the *uncached* input count (``input_tokens`` - from the API response). Cache counts are passed separately. - - ``model_cost_multiplier`` scales the final weighted total to reflect - relative model cost. Use 5.0 for Opus (5× more expensive than Sonnet) - so that Opus turns deplete the rate limit faster, proportional to cost. + ``cost_microdollars`` is the real generation cost reported by the + provider (OpenRouter's ``usage.cost`` or the Claude Agent SDK's + ``total_cost_usd`` converted to microdollars). Because the provider + cost already reflects model pricing and cache discounts, this function + carries no pricing table or weighting — it just increments counters. Args: user_id: The user's ID. - prompt_tokens: Uncached input tokens. - completion_tokens: Output tokens. - cache_read_tokens: Tokens served from prompt cache (10% cost). - cache_creation_tokens: Tokens written to prompt cache (25% cost). - model_cost_multiplier: Relative model cost factor (1.0 = Sonnet, 5.0 = Opus). + cost_microdollars: Spend to record in microdollars (1 USD = 1_000_000). + Non-positive values are ignored. """ - prompt_tokens = max(0, prompt_tokens) - completion_tokens = max(0, completion_tokens) - cache_read_tokens = max(0, cache_read_tokens) - cache_creation_tokens = max(0, cache_creation_tokens) - - weighted_input = ( - prompt_tokens - + round(cache_creation_tokens * 0.25) - + round(cache_read_tokens * 0.1) - ) - total = round( - (weighted_input + completion_tokens) * max(1.0, model_cost_multiplier) - ) - if total <= 0: + cost_microdollars = max(0, cost_microdollars) + if cost_microdollars <= 0: return - raw_total = ( - prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens - ) - logger.info( - "Recording token usage for %s: raw=%d, weighted=%d, multiplier=%.1fx " - "(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)", - user_id[:8], - raw_total, - total, - model_cost_multiplier, - prompt_tokens, - cache_read_tokens, - cache_creation_tokens, - completion_tokens, - ) + logger.info("Recording copilot spend: %d microdollars", cost_microdollars) now = datetime.now(UTC) try: redis = await get_redis_async() - # transaction=False: these are independent INCRBY+EXPIRE pairs on - # separate keys — no cross-key atomicity needed. Skipping - # MULTI/EXEC avoids the overhead. If the connection drops between - # INCRBY and EXPIRE the key survives until the next date-based key - # rotation (daily/weekly), so the memory-leak risk is negligible. - pipe = redis.pipeline(transaction=False) + # Use MULTI/EXEC so each INCRBY/EXPIRE pair is atomic — guarantees + # the TTL is set even if the connection drops mid-pipeline, so + # counters can never survive past their date-based rotation window. + pipe = redis.pipeline(transaction=True) # Daily counter (expires at next midnight UTC) d_key = _daily_key(user_id, now=now) - pipe.incrby(d_key, total) + pipe.incrby(d_key, cost_microdollars) seconds_until_daily_reset = int( (_daily_reset_time(now=now) - now).total_seconds() ) @@ -381,7 +417,7 @@ async def record_token_usage( # Weekly counter (expires end of week) w_key = _weekly_key(user_id, now=now) - pipe.incrby(w_key, total) + pipe.incrby(w_key, cost_microdollars) seconds_until_weekly_reset = int( (_weekly_reset_time(now=now) - now).total_seconds() ) @@ -390,8 +426,8 @@ async def record_token_usage( await pipe.execute() except (RedisError, ConnectionError, OSError): logger.warning( - "Redis unavailable for recording token usage (tokens=%d)", - total, + "Redis unavailable for recording cost usage (microdollars=%d)", + cost_microdollars, ) @@ -598,37 +634,41 @@ async def get_global_rate_limits( ) -> tuple[int, int, SubscriptionTier]: """Resolve global rate limits from LaunchDarkly, falling back to config. - The base limits (from LD or config) are multiplied by the user's - tier multiplier so that higher tiers receive proportionally larger - allowances. + Values are microdollars. The base limits (from LD or config) are + multiplied by the user's tier multiplier so that higher tiers receive + proportionally larger allowances. Args: user_id: User ID for LD flag evaluation context. - config_daily: Fallback daily limit from ChatConfig. - config_weekly: Fallback weekly limit from ChatConfig. + config_daily: Fallback daily cost limit (microdollars) from ChatConfig. + config_weekly: Fallback weekly cost limit (microdollars) from ChatConfig. Returns: - (daily_token_limit, weekly_token_limit, tier) 3-tuple. + (daily_cost_limit, weekly_cost_limit, tier) — limits in microdollars. """ # Lazy import to avoid circular dependency: # rate_limit -> feature_flag -> settings -> ... -> rate_limit from backend.util.feature_flag import Flag, get_feature_flag_value - daily_raw = await get_feature_flag_value( - Flag.COPILOT_DAILY_TOKEN_LIMIT.value, user_id, config_daily - ) - weekly_raw = await get_feature_flag_value( - Flag.COPILOT_WEEKLY_TOKEN_LIMIT.value, user_id, config_weekly + # Fetch daily + weekly flags in parallel — each LD evaluation is an + # independent network round-trip, so gather cuts latency roughly in half. + daily_raw, weekly_raw = await asyncio.gather( + get_feature_flag_value( + Flag.COPILOT_DAILY_COST_LIMIT.value, user_id, config_daily + ), + get_feature_flag_value( + Flag.COPILOT_WEEKLY_COST_LIMIT.value, user_id, config_weekly + ), ) try: daily = max(0, int(daily_raw)) except (TypeError, ValueError): - logger.warning("Invalid LD value for daily token limit: %r", daily_raw) + logger.warning("Invalid LD value for daily cost limit: %r", daily_raw) daily = config_daily try: weekly = max(0, int(weekly_raw)) except (TypeError, ValueError): - logger.warning("Invalid LD value for weekly token limit: %r", weekly_raw) + logger.warning("Invalid LD value for weekly cost limit: %r", weekly_raw) weekly = config_weekly # Apply tier multiplier diff --git a/autogpt_platform/backend/backend/copilot/rate_limit_test.py b/autogpt_platform/backend/backend/copilot/rate_limit_test.py index 577093c752..3787796c17 100644 --- a/autogpt_platform/backend/backend/copilot/rate_limit_test.py +++ b/autogpt_platform/backend/backend/copilot/rate_limit_test.py @@ -24,7 +24,7 @@ from .rate_limit import ( get_usage_status, get_user_tier, increment_daily_reset_count, - record_token_usage, + record_cost_usage, release_reset_lock, reset_daily_usage, reset_user_usage, @@ -82,7 +82,7 @@ class TestGetUsageStatus: return_value=mock_redis, ): status = await get_usage_status( - _USER, daily_token_limit=10000, weekly_token_limit=50000 + _USER, daily_cost_limit=10000, weekly_cost_limit=50000 ) assert isinstance(status, CoPilotUsageStatus) @@ -98,7 +98,7 @@ class TestGetUsageStatus: side_effect=ConnectionError("Redis down"), ): status = await get_usage_status( - _USER, daily_token_limit=10000, weekly_token_limit=50000 + _USER, daily_cost_limit=10000, weekly_cost_limit=50000 ) assert status.daily.used == 0 @@ -115,7 +115,7 @@ class TestGetUsageStatus: return_value=mock_redis, ): status = await get_usage_status( - _USER, daily_token_limit=10000, weekly_token_limit=50000 + _USER, daily_cost_limit=10000, weekly_cost_limit=50000 ) assert status.daily.used == 0 @@ -132,7 +132,7 @@ class TestGetUsageStatus: return_value=mock_redis, ): status = await get_usage_status( - _USER, daily_token_limit=10000, weekly_token_limit=50000 + _USER, daily_cost_limit=10000, weekly_cost_limit=50000 ) assert status.daily.used == 500 @@ -148,7 +148,7 @@ class TestGetUsageStatus: return_value=mock_redis, ): status = await get_usage_status( - _USER, daily_token_limit=10000, weekly_token_limit=50000 + _USER, daily_cost_limit=10000, weekly_cost_limit=50000 ) now = datetime.now(UTC) @@ -174,7 +174,7 @@ class TestCheckRateLimit: ): # Should not raise await check_rate_limit( - _USER, daily_token_limit=10000, weekly_token_limit=50000 + _USER, daily_cost_limit=10000, weekly_cost_limit=50000 ) @pytest.mark.asyncio @@ -188,7 +188,7 @@ class TestCheckRateLimit: ): with pytest.raises(RateLimitExceeded) as exc_info: await check_rate_limit( - _USER, daily_token_limit=10000, weekly_token_limit=50000 + _USER, daily_cost_limit=10000, weekly_cost_limit=50000 ) assert exc_info.value.window == "daily" @@ -203,7 +203,7 @@ class TestCheckRateLimit: ): with pytest.raises(RateLimitExceeded) as exc_info: await check_rate_limit( - _USER, daily_token_limit=10000, weekly_token_limit=50000 + _USER, daily_cost_limit=10000, weekly_cost_limit=50000 ) assert exc_info.value.window == "weekly" @@ -216,7 +216,7 @@ class TestCheckRateLimit: ): # Should not raise await check_rate_limit( - _USER, daily_token_limit=10000, weekly_token_limit=50000 + _USER, daily_cost_limit=10000, weekly_cost_limit=50000 ) @pytest.mark.asyncio @@ -229,15 +229,15 @@ class TestCheckRateLimit: return_value=mock_redis, ): # Should not raise — limits of 0 mean unlimited - await check_rate_limit(_USER, daily_token_limit=0, weekly_token_limit=0) + await check_rate_limit(_USER, daily_cost_limit=0, weekly_cost_limit=0) # --------------------------------------------------------------------------- -# record_token_usage +# record_cost_usage # --------------------------------------------------------------------------- -class TestRecordTokenUsage: +class TestRecordCostUsage: @staticmethod def _make_pipeline_mock() -> MagicMock: """Create a pipeline mock with sync methods and async execute.""" @@ -255,27 +255,40 @@ class TestRecordTokenUsage: "backend.copilot.rate_limit.get_redis_async", return_value=mock_redis, ): - await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50) + await record_cost_usage(_USER, cost_microdollars=123_456) - # Should call incrby twice (daily + weekly) with total=150 + # Should call incrby twice (daily + weekly) with the same cost incrby_calls = mock_pipe.incrby.call_args_list assert len(incrby_calls) == 2 - assert incrby_calls[0].args[1] == 150 # daily - assert incrby_calls[1].args[1] == 150 # weekly + assert incrby_calls[0].args[1] == 123_456 # daily + assert incrby_calls[1].args[1] == 123_456 # weekly @pytest.mark.asyncio - async def test_skips_when_zero_tokens(self): + async def test_skips_when_cost_is_zero(self): mock_redis = AsyncMock() with patch( "backend.copilot.rate_limit.get_redis_async", return_value=mock_redis, ): - await record_token_usage(_USER, prompt_tokens=0, completion_tokens=0) + await record_cost_usage(_USER, cost_microdollars=0) # Should not call pipeline at all mock_redis.pipeline.assert_not_called() + @pytest.mark.asyncio + async def test_skips_when_cost_is_negative(self): + """Negative costs are clamped to zero and skip the pipeline.""" + mock_redis = AsyncMock() + + with patch( + "backend.copilot.rate_limit.get_redis_async", + return_value=mock_redis, + ): + await record_cost_usage(_USER, cost_microdollars=-10) + + mock_redis.pipeline.assert_not_called() + @pytest.mark.asyncio async def test_sets_expire_on_both_keys(self): """Pipeline should call expire for both daily and weekly keys.""" @@ -287,7 +300,7 @@ class TestRecordTokenUsage: "backend.copilot.rate_limit.get_redis_async", return_value=mock_redis, ): - await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50) + await record_cost_usage(_USER, cost_microdollars=5_000) expire_calls = mock_pipe.expire.call_args_list assert len(expire_calls) == 2 @@ -308,32 +321,7 @@ class TestRecordTokenUsage: side_effect=ConnectionError("Redis down"), ): # Should not raise - await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50) - - @pytest.mark.asyncio - async def test_cost_weighted_counting(self): - """Cached tokens should be weighted: cache_read=10%, cache_create=25%.""" - mock_pipe = self._make_pipeline_mock() - mock_redis = AsyncMock() - mock_redis.pipeline = lambda **_kw: mock_pipe - - with patch( - "backend.copilot.rate_limit.get_redis_async", - return_value=mock_redis, - ): - await record_token_usage( - _USER, - prompt_tokens=100, # uncached → 100 - completion_tokens=50, # output → 50 - cache_read_tokens=10000, # 10% → 1000 - cache_creation_tokens=400, # 25% → 100 - ) - - # Expected weighted total: 100 + 1000 + 100 + 50 = 1250 - incrby_calls = mock_pipe.incrby.call_args_list - assert len(incrby_calls) == 2 - assert incrby_calls[0].args[1] == 1250 # daily - assert incrby_calls[1].args[1] == 1250 # weekly + await record_cost_usage(_USER, cost_microdollars=5_000) @pytest.mark.asyncio async def test_handles_redis_error_during_pipeline_execute(self): @@ -348,7 +336,7 @@ class TestRecordTokenUsage: return_value=mock_redis, ): # Should not raise — fail-open - await record_token_usage(_USER, prompt_tokens=100, completion_tokens=50) + await record_cost_usage(_USER, cost_microdollars=5_000) # --------------------------------------------------------------------------- @@ -819,7 +807,7 @@ class TestTierLimitsRespected: assert tier == SubscriptionTier.PRO # Should NOT raise — 3M < 12.5M await check_rate_limit( - _USER, daily_token_limit=daily, weekly_token_limit=weekly + _USER, daily_cost_limit=daily, weekly_cost_limit=weekly ) @pytest.mark.asyncio @@ -853,7 +841,7 @@ class TestTierLimitsRespected: # Should raise — 2.5M >= 2.5M with pytest.raises(RateLimitExceeded): await check_rate_limit( - _USER, daily_token_limit=daily, weekly_token_limit=weekly + _USER, daily_cost_limit=daily, weekly_cost_limit=weekly ) @pytest.mark.asyncio @@ -885,7 +873,7 @@ class TestTierLimitsRespected: assert tier == SubscriptionTier.ENTERPRISE # Should NOT raise — 100M < 150M await check_rate_limit( - _USER, daily_token_limit=daily, weekly_token_limit=weekly + _USER, daily_cost_limit=daily, weekly_cost_limit=weekly ) @@ -912,7 +900,7 @@ class TestResetDailyUsage: "backend.copilot.rate_limit.get_redis_async", return_value=mock_redis, ): - result = await reset_daily_usage(_USER, daily_token_limit=10000) + result = await reset_daily_usage(_USER, daily_cost_limit=10000) assert result is True mock_pipe.delete.assert_called_once() @@ -928,7 +916,7 @@ class TestResetDailyUsage: "backend.copilot.rate_limit.get_redis_async", return_value=mock_redis, ): - await reset_daily_usage(_USER, daily_token_limit=10000) + await reset_daily_usage(_USER, daily_cost_limit=10000) mock_pipe.decrby.assert_called_once() mock_redis.set.assert_not_called() # 35000 > 0, no clamp needed @@ -944,14 +932,14 @@ class TestResetDailyUsage: "backend.copilot.rate_limit.get_redis_async", return_value=mock_redis, ): - await reset_daily_usage(_USER, daily_token_limit=10000) + await reset_daily_usage(_USER, daily_cost_limit=10000) mock_pipe.decrby.assert_called_once() mock_redis.set.assert_called_once() @pytest.mark.asyncio async def test_no_weekly_reduction_when_daily_limit_zero(self): - """When daily_token_limit is 0, weekly counter should not be touched.""" + """When daily_cost_limit is 0, weekly counter should not be touched.""" mock_pipe = self._make_pipeline_mock() mock_pipe.execute = AsyncMock(return_value=[1]) # only delete result mock_redis = AsyncMock() @@ -961,7 +949,7 @@ class TestResetDailyUsage: "backend.copilot.rate_limit.get_redis_async", return_value=mock_redis, ): - await reset_daily_usage(_USER, daily_token_limit=0) + await reset_daily_usage(_USER, daily_cost_limit=0) mock_pipe.delete.assert_called_once() mock_pipe.decrby.assert_not_called() @@ -972,7 +960,7 @@ class TestResetDailyUsage: "backend.copilot.rate_limit.get_redis_async", side_effect=ConnectionError("Redis down"), ): - result = await reset_daily_usage(_USER, daily_token_limit=10000) + result = await reset_daily_usage(_USER, daily_cost_limit=10000) assert result is False diff --git a/autogpt_platform/backend/backend/copilot/reset_usage_test.py b/autogpt_platform/backend/backend/copilot/reset_usage_test.py index cbbf714df0..d5b4ee140e 100644 --- a/autogpt_platform/backend/backend/copilot/reset_usage_test.py +++ b/autogpt_platform/backend/backend/copilot/reset_usage_test.py @@ -16,14 +16,14 @@ from backend.util.exceptions import InsufficientBalanceError # Minimal config mock matching ChatConfig fields used by the endpoint. def _make_config( rate_limit_reset_cost: int = 500, - daily_token_limit: int = 2_500_000, - weekly_token_limit: int = 12_500_000, + daily_cost_limit_microdollars: int = 10_000_000, + weekly_cost_limit_microdollars: int = 50_000_000, max_daily_resets: int = 5, ): cfg = MagicMock() cfg.rate_limit_reset_cost = rate_limit_reset_cost - cfg.daily_token_limit = daily_token_limit - cfg.weekly_token_limit = weekly_token_limit + cfg.daily_cost_limit_microdollars = daily_cost_limit_microdollars + cfg.weekly_cost_limit_microdollars = weekly_cost_limit_microdollars cfg.max_daily_resets = max_daily_resets return cfg @@ -77,10 +77,10 @@ class TestResetCopilotUsage: assert "not available" in exc_info.value.detail async def test_no_daily_limit_returns_400(self): - """When daily_token_limit=0 (unlimited), endpoint returns 400.""" + """When daily_cost_limit=0 (unlimited), endpoint returns 400.""" with ( - patch(f"{_MODULE}.config", _make_config(daily_token_limit=0)), + patch(f"{_MODULE}.config", _make_config(daily_cost_limit_microdollars=0)), patch(f"{_MODULE}.settings", _mock_settings()), _mock_rate_limits(daily=0), ): diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py index ea0a135559..e4f29a2b65 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service.py @@ -165,11 +165,6 @@ _MAX_STREAM_ATTEMPTS = 3 # self-correct. The limit is generous to allow recovery attempts. _EMPTY_TOOL_CALL_LIMIT = 5 -# Cost multiplier for Opus model turns — Opus is ~5× more expensive than Sonnet -# ($15/$75 vs $3/$15 per M tokens). Applied to rate-limit counters so Opus -# turns deplete quota proportionally faster. -_OPUS_COST_MULTIPLIER = 5.0 - # User-facing error shown when the empty-tool-call circuit breaker trips. _CIRCUIT_BREAKER_ERROR_MSG = ( "AutoPilot was unable to complete the tool call " @@ -725,22 +720,20 @@ def _resolve_fallback_model() -> str | None: return _normalize_model_name(raw) -async def _resolve_model_and_multiplier( +async def _resolve_sdk_model_for_request( model: "CopilotLlmModel | None", session_id: str, -) -> tuple[str | None, float]: - """Resolve the SDK model string and rate-limit cost multiplier for a turn. +) -> str | None: + """Resolve the SDK model string for a turn. Priority (highest first): 1. Explicit per-request ``model`` tier from the frontend toggle. 2. Global config default (``_resolve_sdk_model()``). - Returns a ``(sdk_model, cost_multiplier)`` pair. - ``sdk_model`` is ``None`` when the Claude Code subscription default applies. - ``cost_multiplier`` is 5.0 for Opus, 1.0 otherwise. + Returns ``None`` when the Claude Code subscription default applies. + Rate-limit accounting no longer applies a multiplier — the real turn + cost (reported by the SDK) already reflects model-pricing differences. """ - sdk_model = _resolve_sdk_model() - if model == "advanced": sdk_model = _normalize_model_name(config.advanced_model) logger.info( @@ -748,7 +741,7 @@ async def _resolve_model_and_multiplier( session_id[:12] if session_id else "?", sdk_model, ) - return sdk_model, _OPUS_COST_MULTIPLIER + return sdk_model if model == "standard": # Reset to config default — respects subscription mode (None = CLI default). @@ -758,13 +751,9 @@ async def _resolve_model_and_multiplier( session_id[:12] if session_id else "?", sdk_model or "subscription-default", ) - return sdk_model, 1.0 + return sdk_model - # No per-request override; derive multiplier from final resolved model. - cost_multiplier = ( - _OPUS_COST_MULTIPLIER if sdk_model and "opus" in sdk_model else 1.0 - ) - return sdk_model, cost_multiplier + return _resolve_sdk_model() _MAX_TRANSIENT_BACKOFF_SECONDS = 30 @@ -2895,7 +2884,6 @@ async def stream_chat_completion_sdk( # Defaults ensure the finally block can always reference these safely even when # an early return (e.g. sdk_cwd error) skips their normal assignment below. sdk_model: str | None = None - model_cost_multiplier: float = 1.0 # Make sure there is no more code between the lock acquisition and try-block. try: @@ -3012,10 +3000,8 @@ async def stream_chat_completion_sdk( mcp_server = create_copilot_mcp_server(use_e2b=use_e2b) - # Resolve model and cost multiplier (request tier → config default). - sdk_model, model_cost_multiplier = await _resolve_model_and_multiplier( - model, session_id - ) + # Resolve model (request tier → config default). + sdk_model = await _resolve_sdk_model_for_request(model, session_id) # Track SDK-internal compaction (PreCompact hook → start, next msg → end) compaction = CompactionTracker() @@ -3813,7 +3799,6 @@ async def stream_chat_completion_sdk( cost_usd=turn_cost_usd, model=sdk_model or config.model, provider="anthropic", - model_cost_multiplier=model_cost_multiplier, ) # --- Persist session messages --- diff --git a/autogpt_platform/backend/backend/copilot/token_tracking.py b/autogpt_platform/backend/backend/copilot/token_tracking.py index 19406ced93..f5ace5e749 100644 --- a/autogpt_platform/backend/backend/copilot/token_tracking.py +++ b/autogpt_platform/backend/backend/copilot/token_tracking.py @@ -1,9 +1,9 @@ -"""Shared token-usage persistence and rate-limit recording. +"""Shared usage persistence and rate-limit recording. Both the baseline (OpenRouter) and SDK (Anthropic) service layers need to: 1. Append a ``Usage`` record to the session. - 2. Log the turn's token counts. - 3. Record weighted usage in Redis for rate-limiting. + 2. Log the turn's token counts and cost. + 3. Record the real generation cost in Redis for rate-limiting. 4. Write a PlatformCostLog entry for admin cost tracking. This module extracts that common logic so both paths stay in sync. @@ -19,7 +19,7 @@ from backend.data.db_accessors import platform_cost_db from backend.data.platform_cost import PlatformCostEntry, usd_to_microdollars from .model import ChatSession, Usage -from .rate_limit import record_token_usage +from .rate_limit import record_cost_usage logger = logging.getLogger(__name__) @@ -96,9 +96,14 @@ async def persist_and_record_usage( cost_usd: float | str | None = None, model: str | None = None, provider: str = "open_router", - model_cost_multiplier: float = 1.0, ) -> int: - """Persist token usage to session and record for rate limiting. + """Persist token usage to session and record generation cost for rate limiting. + + Rate-limit counters are charged in microdollars against the provider's + reported cost (``cost_usd``), so cache discounts and cross-model pricing + differences are already reflected. When cost is unknown the turn is + logged but the rate-limit counter is left alone — the caller logs an + error at the point the absence is detected. Args: session: The chat session to append usage to (may be None on error). @@ -108,11 +113,11 @@ async def persist_and_record_usage( cache_read_tokens: Tokens served from prompt cache (Anthropic only). cache_creation_tokens: Tokens written to prompt cache (Anthropic only). log_prefix: Prefix for log messages (e.g. "[SDK]", "[Baseline]"). - cost_usd: Optional cost for logging (float from SDK, str otherwise). + cost_usd: Real generation cost for the turn (float from SDK or parsed + from OpenRouter usage.cost). ``None`` means the provider did not + report a cost and rate limiting is skipped for this turn. + model: Model identifier for cost log attribution. provider: Cost provider name (e.g. "anthropic", "open_router"). - model_cost_multiplier: Relative model cost factor for rate limiting - (1.0 = Sonnet/default, 5.0 = Opus). Scales the token counter so - more expensive models deplete the rate limit proportionally faster. Returns: The computed total_tokens (prompt + completion; cache excluded). @@ -156,37 +161,51 @@ async def persist_and_record_usage( else: logger.info( f"{log_prefix} Turn usage: prompt={prompt_tokens}, completion={completion_tokens}," - f" total={total_tokens}" + f" total={total_tokens}, cost_usd={cost_usd}" ) - if user_id: + cost_float: float | None = None + if cost_usd is not None: try: - await record_token_usage( - user_id=user_id, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - cache_read_tokens=cache_read_tokens, - cache_creation_tokens=cache_creation_tokens, - model_cost_multiplier=model_cost_multiplier, + val = float(cost_usd) + except (ValueError, TypeError): + logger.error( + "%s cost_usd is not numeric: %r — rate limit skipped", + log_prefix, + cost_usd, ) - except Exception as usage_err: - logger.warning("%s Failed to record token usage: %s", log_prefix, usage_err) + else: + if not math.isfinite(val): + logger.error( + "%s cost_usd is non-finite: %r — rate limit skipped", + log_prefix, + val, + ) + elif val < 0: + logger.warning( + "%s cost_usd %s is negative — skipping rate-limit + cost log", + log_prefix, + val, + ) + else: + cost_float = val + + cost_microdollars = usd_to_microdollars(cost_float) + + if user_id and cost_microdollars is not None and cost_microdollars > 0: + # record_cost_usage() owns its fail-open handling for Redis/network + # errors. Don't wrap with a broad except here — unexpected accounting + # bugs should surface instead of being silently logged as warnings. + await record_cost_usage( + user_id=user_id, + cost_microdollars=cost_microdollars, + ) # Log to PlatformCostLog for admin cost dashboard. # Include entries where cost_usd is set even if token count is 0 # (e.g. fully-cached Anthropic responses where only cache tokens # accumulate a charge without incrementing total_tokens). - if user_id and (total_tokens > 0 or cost_usd is not None): - cost_float = None - if cost_usd is not None: - try: - val = float(cost_usd) - if math.isfinite(val) and val >= 0: - cost_float = val - except (ValueError, TypeError): - pass - - cost_microdollars = usd_to_microdollars(cost_float) + if user_id and (total_tokens > 0 or cost_float is not None): session_id = session.session_id if session else None if cost_float is not None: diff --git a/autogpt_platform/backend/backend/copilot/token_tracking_test.py b/autogpt_platform/backend/backend/copilot/token_tracking_test.py index 11757ce541..ff5957e1f5 100644 --- a/autogpt_platform/backend/backend/copilot/token_tracking_test.py +++ b/autogpt_platform/backend/backend/copilot/token_tracking_test.py @@ -37,7 +37,7 @@ class TestTotalTokens: async def test_returns_prompt_plus_completion(self): """total_tokens = prompt + completion (cache excluded from total).""" with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ): total = await persist_and_record_usage( @@ -63,7 +63,7 @@ class TestTotalTokens: async def test_cache_tokens_excluded_from_total(self): """Cache tokens are stored separately and not added to total_tokens.""" with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ): total = await persist_and_record_usage( @@ -81,7 +81,7 @@ class TestTotalTokens: async def test_baseline_path_no_cache(self): """Baseline (OpenRouter) path passes no cache tokens; total = prompt + completion.""" with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ): total = await persist_and_record_usage( @@ -97,7 +97,7 @@ class TestTotalTokens: async def test_sdk_path_with_cache(self): """SDK (Anthropic) path passes cache tokens; total still = prompt + completion.""" with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ): total = await persist_and_record_usage( @@ -123,7 +123,7 @@ class TestSessionPersistence: async def test_appends_usage_to_session(self): session = _make_session() with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ): await persist_and_record_usage( @@ -144,7 +144,7 @@ class TestSessionPersistence: async def test_appends_cache_breakdown_to_session(self): session = _make_session() with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ): await persist_and_record_usage( @@ -163,7 +163,7 @@ class TestSessionPersistence: async def test_multiple_turns_append_multiple_records(self): session = _make_session() with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ): await persist_and_record_usage( @@ -178,7 +178,7 @@ class TestSessionPersistence: async def test_none_session_does_not_raise(self): """When session is None (e.g. error path), no exception should be raised.""" with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ): total = await persist_and_record_usage( @@ -210,10 +210,11 @@ class TestSessionPersistence: class TestRateLimitRecording: @pytest.mark.asyncio - async def test_calls_record_token_usage_when_user_id_present(self): + async def test_calls_record_cost_usage_when_cost_and_user_id_present(self): + """Rate-limit counter is charged with the real provider cost (microdollars).""" mock_record = AsyncMock() with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new=mock_record, ): await persist_and_record_usage( @@ -223,22 +224,35 @@ class TestRateLimitRecording: completion_tokens=50, cache_read_tokens=1000, cache_creation_tokens=200, + cost_usd=0.0123, ) mock_record.assert_awaited_once_with( user_id="user-abc", - prompt_tokens=100, - completion_tokens=50, - cache_read_tokens=1000, - cache_creation_tokens=200, - model_cost_multiplier=1.0, + cost_microdollars=12_300, ) + @pytest.mark.asyncio + async def test_skips_record_when_cost_is_missing(self): + """Without a provider cost we have no authoritative figure to charge.""" + mock_record = AsyncMock() + with patch( + "backend.copilot.token_tracking.record_cost_usage", + new=mock_record, + ): + await persist_and_record_usage( + session=None, + user_id="user-abc", + prompt_tokens=100, + completion_tokens=50, + ) + mock_record.assert_not_awaited() + @pytest.mark.asyncio async def test_skips_record_when_user_id_is_none(self): """Anonymous sessions should not create Redis keys.""" mock_record = AsyncMock() with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new=mock_record, ): await persist_and_record_usage( @@ -246,32 +260,38 @@ class TestRateLimitRecording: user_id=None, prompt_tokens=100, completion_tokens=50, + cost_usd=0.001, ) mock_record.assert_not_awaited() @pytest.mark.asyncio - async def test_record_failure_does_not_raise(self): - """A Redis error in record_token_usage should be swallowed (fail-open).""" - mock_record = AsyncMock(side_effect=ConnectionError("Redis down")) + async def test_record_usage_bubbles_unexpected_error(self): + """Unexpected errors from record_cost_usage must propagate. + + record_cost_usage() owns its own (RedisError, ConnectionError, OSError) + fail-open handling. Anything else is a real accounting bug and + should not be silently swallowed at this layer. + """ + mock_record = AsyncMock(side_effect=RuntimeError("boom")) with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new=mock_record, ): - # Should not raise - total = await persist_and_record_usage( - session=None, - user_id="user-xyz", - prompt_tokens=100, - completion_tokens=50, - ) - assert total == 150 + with pytest.raises(RuntimeError, match="boom"): + await persist_and_record_usage( + session=None, + user_id="user-xyz", + prompt_tokens=100, + completion_tokens=50, + cost_usd=0.002, + ) @pytest.mark.asyncio - async def test_skips_record_when_zero_tokens(self): - """Returns 0 before calling record_token_usage when tokens are zero.""" + async def test_skips_record_when_zero_tokens_and_no_cost(self): + """Returns 0 before calling record_cost_usage when there is nothing to record.""" mock_record = AsyncMock() with patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new=mock_record, ): await persist_and_record_usage( @@ -295,7 +315,7 @@ class TestPlatformCostLogging: mock_log = AsyncMock() with ( patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ), patch( @@ -336,7 +356,7 @@ class TestPlatformCostLogging: mock_log = AsyncMock() with ( patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ), patch( @@ -369,7 +389,7 @@ class TestPlatformCostLogging: mock_log = AsyncMock() with ( patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ), patch( @@ -394,7 +414,7 @@ class TestPlatformCostLogging: mock_log = AsyncMock() with ( patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ), patch( @@ -423,7 +443,7 @@ class TestPlatformCostLogging: mock_log = AsyncMock() with ( patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ), patch( @@ -452,7 +472,7 @@ class TestPlatformCostLogging: mock_log = AsyncMock() with ( patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ), patch( @@ -479,7 +499,7 @@ class TestPlatformCostLogging: mock_log = AsyncMock() with ( patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ), patch( @@ -509,7 +529,7 @@ class TestPlatformCostLogging: mock_log = AsyncMock() with ( patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ), patch( @@ -545,7 +565,7 @@ class TestPlatformCostLogging: mock_log = AsyncMock() with ( patch( - "backend.copilot.token_tracking.record_token_usage", + "backend.copilot.token_tracking.record_cost_usage", new_callable=AsyncMock, ), patch( diff --git a/autogpt_platform/backend/backend/util/feature_flag.py b/autogpt_platform/backend/backend/util/feature_flag.py index c341666cdb..1e29ff4102 100644 --- a/autogpt_platform/backend/backend/util/feature_flag.py +++ b/autogpt_platform/backend/backend/util/feature_flag.py @@ -42,8 +42,8 @@ class Flag(str, Enum): CHAT = "chat" CHAT_MODE_OPTION = "chat-mode-option" COPILOT_SDK = "copilot-sdk" - COPILOT_DAILY_TOKEN_LIMIT = "copilot-daily-token-limit" - COPILOT_WEEKLY_TOKEN_LIMIT = "copilot-weekly-token-limit" + COPILOT_DAILY_COST_LIMIT = "copilot-daily-cost-limit-microdollars" + COPILOT_WEEKLY_COST_LIMIT = "copilot-weekly-cost-limit-microdollars" STRIPE_PRICE_PRO = "stripe-price-id-pro" STRIPE_PRICE_BUSINESS = "stripe-price-id-business" GRAPHITI_MEMORY = "graphiti-memory" diff --git a/autogpt_platform/backend/snapshots/get_rate_limit b/autogpt_platform/backend/snapshots/get_rate_limit index 5bae448ba2..3ac1b94222 100644 --- a/autogpt_platform/backend/snapshots/get_rate_limit +++ b/autogpt_platform/backend/snapshots/get_rate_limit @@ -1,9 +1,9 @@ { - "daily_token_limit": 2500000, - "daily_tokens_used": 500000, + "daily_cost_limit_microdollars": 2500000, + "daily_cost_used_microdollars": 500000, "tier": "FREE", "user_email": "target@example.com", "user_id": "5e53486c-cf57-477e-ba2a-cb02dc828e1c", - "weekly_token_limit": 12500000, - "weekly_tokens_used": 3000000 + "weekly_cost_limit_microdollars": 12500000, + "weekly_cost_used_microdollars": 3000000 } diff --git a/autogpt_platform/backend/snapshots/reset_user_usage_daily_and_weekly b/autogpt_platform/backend/snapshots/reset_user_usage_daily_and_weekly index c73be30be5..b5361be34a 100644 --- a/autogpt_platform/backend/snapshots/reset_user_usage_daily_and_weekly +++ b/autogpt_platform/backend/snapshots/reset_user_usage_daily_and_weekly @@ -1,9 +1,9 @@ { - "daily_token_limit": 2500000, - "daily_tokens_used": 0, + "daily_cost_limit_microdollars": 2500000, + "daily_cost_used_microdollars": 0, "tier": "FREE", "user_email": "target@example.com", "user_id": "5e53486c-cf57-477e-ba2a-cb02dc828e1c", - "weekly_token_limit": 12500000, - "weekly_tokens_used": 0 + "weekly_cost_limit_microdollars": 12500000, + "weekly_cost_used_microdollars": 0 } diff --git a/autogpt_platform/backend/snapshots/reset_user_usage_daily_only b/autogpt_platform/backend/snapshots/reset_user_usage_daily_only index 5b205a8bfb..256d8e893d 100644 --- a/autogpt_platform/backend/snapshots/reset_user_usage_daily_only +++ b/autogpt_platform/backend/snapshots/reset_user_usage_daily_only @@ -1,9 +1,9 @@ { - "daily_token_limit": 2500000, - "daily_tokens_used": 0, + "daily_cost_limit_microdollars": 2500000, + "daily_cost_used_microdollars": 0, "tier": "FREE", "user_email": "target@example.com", "user_id": "5e53486c-cf57-477e-ba2a-cb02dc828e1c", - "weekly_token_limit": 12500000, - "weekly_tokens_used": 3000000 + "weekly_cost_limit_microdollars": 12500000, + "weekly_cost_used_microdollars": 3000000 } diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/components/UsageBar.tsx b/autogpt_platform/frontend/src/app/(platform)/admin/components/UsageBar.tsx index de95cf0e47..442ebf43bc 100644 --- a/autogpt_platform/frontend/src/app/(platform)/admin/components/UsageBar.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/admin/components/UsageBar.tsx @@ -1,10 +1,6 @@ "use client"; -export function formatTokens(tokens: number): string { - if (tokens >= 1_000_000) return `${(tokens / 1_000_000).toFixed(1)}M`; - if (tokens >= 1_000) return `${(tokens / 1_000).toFixed(0)}K`; - return tokens.toString(); -} +import { formatMicrodollarsAsUsd } from "@/app/(platform)/copilot/components/usageHelpers"; export function UsageBar({ used, limit }: { used: number; limit: number }) { if (limit === 0) { @@ -17,8 +13,8 @@ export function UsageBar({ used, limit }: { used: number; limit: number }) { return (
- {formatTokens(used)} used - {formatTokens(limit)} limit + {formatMicrodollarsAsUsd(used)} spent + {formatMicrodollarsAsUsd(limit)} limit
{ + it('renders "Unlimited" when limit is 0', () => { + render(); + expect(screen.getByText("Unlimited")).toBeDefined(); + }); + + it("renders spent + limit in USD", () => { + render(); + expect(screen.getByText("$1.50 spent")).toBeDefined(); + expect(screen.getByText("$10.00 limit")).toBeDefined(); + }); + + it("renders the computed percentage", () => { + render(); + expect(screen.getByText("50.0% used")).toBeDefined(); + }); + + it("clamps percentage at 100% when over limit", () => { + render(); + expect(screen.getByText("100.0% used")).toBeDefined(); + }); + + it("clamps percentage at 0% for negative used", () => { + render(); + expect(screen.getByText("0.0% used")).toBeDefined(); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/RateLimitDisplay.tsx b/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/RateLimitDisplay.tsx index b216745c35..024b819699 100644 --- a/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/RateLimitDisplay.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/RateLimitDisplay.tsx @@ -88,8 +88,9 @@ export function RateLimitDisplay({ } const nothingToReset = resetWeekly - ? data.daily_tokens_used === 0 && data.weekly_tokens_used === 0 - : data.daily_tokens_used === 0; + ? data.daily_cost_used_microdollars === 0 && + data.weekly_cost_used_microdollars === 0 + : data.daily_cost_used_microdollars === 0; return (
@@ -133,17 +134,17 @@ export function RateLimitDisplay({
-

Daily Usage

+

Daily Spend

-

Weekly Usage

+

Weekly Spend

diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/RateLimitDisplay.test.tsx b/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/RateLimitDisplay.test.tsx index 5425a14ff2..08b5db312b 100644 --- a/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/RateLimitDisplay.test.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/RateLimitDisplay.test.tsx @@ -30,10 +30,10 @@ function makeData( return { user_id: "user-abc-123", user_email: "alice@example.com", - daily_token_limit: 10000, - weekly_token_limit: 50000, - daily_tokens_used: 2500, - weekly_tokens_used: 10000, + daily_cost_limit_microdollars: 10_000_000, + weekly_cost_limit_microdollars: 50_000_000, + daily_cost_used_microdollars: 2_500_000, + weekly_cost_used_microdollars: 10_000_000, tier: "FREE", ...overrides, }; @@ -113,8 +113,8 @@ describe("RateLimitDisplay", () => { it("renders daily and weekly usage sections", () => { render(); - expect(screen.getByText("Daily Usage")).toBeDefined(); - expect(screen.getByText("Weekly Usage")).toBeDefined(); + expect(screen.getByText("Daily Spend")).toBeDefined(); + expect(screen.getByText("Weekly Spend")).toBeDefined(); }); it("renders reset scope dropdown and reset button", () => { @@ -126,7 +126,7 @@ describe("RateLimitDisplay", () => { it("disables reset button when nothing to reset", () => { render( , ); @@ -137,7 +137,7 @@ describe("RateLimitDisplay", () => { it("enables reset button when there is usage to reset", () => { render( , ); @@ -174,7 +174,7 @@ describe("RateLimitDisplay", () => { render( , ); diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/RateLimitManager.test.tsx b/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/RateLimitManager.test.tsx index ab996748f1..8435e6dc6d 100644 --- a/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/RateLimitManager.test.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/RateLimitManager.test.tsx @@ -174,10 +174,10 @@ describe("RateLimitManager", () => { rateLimitData: { user_id: "user-123", user_email: "alice@example.com", - daily_token_limit: 10000, - weekly_token_limit: 50000, - daily_tokens_used: 2500, - weekly_tokens_used: 10000, + daily_cost_limit_microdollars: 10_000_000, + weekly_cost_limit_microdollars: 50_000_000, + daily_cost_used_microdollars: 2_500_000, + weekly_cost_used_microdollars: 10_000_000, tier: "FREE", }, }); @@ -197,10 +197,10 @@ describe("RateLimitManager", () => { rateLimitData: { user_id: "user-123", user_email: "alice@example.com", - daily_token_limit: 10000, - weekly_token_limit: 50000, - daily_tokens_used: 2500, - weekly_tokens_used: 10000, + daily_cost_limit_microdollars: 10_000_000, + weekly_cost_limit_microdollars: 50_000_000, + daily_cost_used_microdollars: 2_500_000, + weekly_cost_used_microdollars: 10_000_000, tier: "FREE", }, }); diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/useRateLimitManager.test.ts b/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/useRateLimitManager.test.ts index d09a74b507..523af7514b 100644 --- a/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/useRateLimitManager.test.ts +++ b/autogpt_platform/frontend/src/app/(platform)/admin/rate-limits/components/__tests__/useRateLimitManager.test.ts @@ -28,10 +28,10 @@ function makeRateLimitResponse(overrides = {}) { return { user_id: "user-123", user_email: "alice@example.com", - daily_token_limit: 10000, - weekly_token_limit: 50000, - daily_tokens_used: 2500, - weekly_tokens_used: 10000, + daily_cost_limit_microdollars: 10_000_000, + weekly_cost_limit_microdollars: 50_000_000, + daily_cost_used_microdollars: 2_500_000, + weekly_cost_used_microdollars: 10_000_000, tier: "FREE", ...overrides, }; @@ -229,8 +229,12 @@ describe("useRateLimitManager", () => { }); it("handleReset calls reset endpoint and updates data", async () => { - const initial = makeRateLimitResponse({ daily_tokens_used: 5000 }); - const after = makeRateLimitResponse({ daily_tokens_used: 0 }); + const initial = makeRateLimitResponse({ + daily_cost_used_microdollars: 5_000_000, + }); + const after = makeRateLimitResponse({ + daily_cost_used_microdollars: 0, + }); mockGetV2GetUserRateLimit.mockResolvedValue({ status: 200, data: initial }); mockPostV2ResetUserRateLimitUsage.mockResolvedValue({ status: 200, @@ -338,7 +342,9 @@ describe("useRateLimitManager", () => { }); it("handleReset throws when endpoint returns non-200 status", async () => { - const initial = makeRateLimitResponse({ daily_tokens_used: 5000 }); + const initial = makeRateLimitResponse({ + daily_cost_used_microdollars: 5_000_000, + }); mockGetV2GetUserRateLimit.mockResolvedValue({ status: 200, data: initial }); mockPostV2ResetUserRateLimitUsage.mockResolvedValue({ status: 500 }); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/CopilotPage.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/CopilotPage.tsx index 158d0b2392..c3ac603073 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/CopilotPage.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/CopilotPage.tsx @@ -1,6 +1,6 @@ "use client"; -import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus"; +import type { CoPilotUsagePublic } from "@/app/api/__generated__/models/coPilotUsagePublic"; import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat"; import { toast } from "@/components/molecules/Toast/use-toast"; import useCredits from "@/hooks/useCredits"; @@ -125,7 +125,7 @@ export function CopilotPage() { isError: usageError, } = useGetV2GetCopilotUsage({ query: { - select: (res) => res.data as CoPilotUsageStatus, + select: (res) => res.data as CoPilotUsagePublic, refetchInterval: 30000, staleTime: 10000, }, @@ -258,9 +258,7 @@ export function CopilotPage() { resetCost={resetCost ?? 0} resetMessage={rateLimitMessage ?? ""} isWeeklyExhausted={ - hasUsage && - usage.weekly.limit > 0 && - usage.weekly.used >= usage.weekly.limit + hasUsage && !!usage.weekly && usage.weekly.percent_used >= 100 } hasInsufficientCredits={hasInsufficientCredits} isBillingEnabled={isBillingEnabled} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/CopilotPage.test.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/CopilotPage.test.tsx index 71791b5694..bef9a2a848 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/CopilotPage.test.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/CopilotPage.test.tsx @@ -39,13 +39,23 @@ vi.mock("@/components/ui/sidebar", () => ({ ), })); -// Mock hooks that hit the network +// Mock hooks that hit the network. Exercise the `select` callback so its +// line counts as covered alongside the rest of the options. vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({ - useGetV2GetCopilotUsage: () => ({ - data: undefined, - isSuccess: false, - isError: false, - }), + useGetV2GetCopilotUsage: (opts: { + query?: { select?: (r: { data: unknown }) => unknown }; + }) => { + const data = { + daily: null, + weekly: null, + tier: "FREE", + reset_cost: 0, + }; + if (typeof opts?.query?.select === "function") { + opts.query.select({ data }); + } + return { data: undefined, isSuccess: false, isError: false }; + }, })); vi.mock("@/hooks/useCredits", () => ({ default: () => ({ credits: null, fetchCredits: vi.fn() }), diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/UsageLimits.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/UsageLimits.tsx index 1420e626b3..711c36c26e 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/UsageLimits.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/UsageLimits.tsx @@ -1,4 +1,4 @@ -import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus"; +import type { CoPilotUsagePublic } from "@/app/api/__generated__/models/coPilotUsagePublic"; import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat"; import useCredits from "@/hooks/useCredits"; import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag"; @@ -14,9 +14,9 @@ import { UsagePanelContent } from "./UsagePanelContent"; export { UsagePanelContent, formatResetTime } from "./UsagePanelContent"; export function UsageLimits() { - const { data: usage, isLoading } = useGetV2GetCopilotUsage({ + const { data: usage, isSuccess } = useGetV2GetCopilotUsage({ query: { - select: (res) => res.data as CoPilotUsageStatus, + select: (res) => res.data as CoPilotUsagePublic, refetchInterval: 30000, staleTime: 10000, }, @@ -28,8 +28,8 @@ export function UsageLimits() { const hasInsufficientCredits = credits !== null && resetCost != null && credits < resetCost; - if (isLoading || !usage?.daily || !usage?.weekly) return null; - if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null; + if (!isSuccess || !usage) return null; + if (!usage.daily && !usage.weekly) return null; return ( diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/UsagePanelContent.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/UsagePanelContent.tsx index 91187816da..9a1c0d1c87 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/UsagePanelContent.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/UsagePanelContent.tsx @@ -1,4 +1,4 @@ -import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus"; +import type { CoPilotUsagePublic } from "@/app/api/__generated__/models/coPilotUsagePublic"; import { Button } from "@/components/atoms/Button/Button"; import Link from "next/link"; import { formatCents, formatResetTime } from "../usageHelpers"; @@ -8,22 +8,17 @@ export { formatResetTime }; function UsageBar({ label, - used, - limit, + percentUsed, resetsAt, }: { label: string; - used: number; - limit: number; + percentUsed: number; resetsAt: Date | string; }) { - if (limit <= 0) return null; - - const rawPercent = (used / limit) * 100; - const percent = Math.min(100, Math.round(rawPercent)); + const percent = Math.min(100, Math.max(0, Math.round(percentUsed))); const isHigh = percent >= 80; const percentLabel = - used > 0 && percent === 0 ? "<1% used" : `${percent}% used`; + percentUsed > 0 && percent === 0 ? "<1% used" : `${percent}% used`; return (
@@ -38,10 +33,15 @@ function UsageBar({
0 ? 1 : 0, percent)}%` }} + style={{ width: `${Math.max(percent > 0 ? 1 : 0, percent)}%` }} />
@@ -79,21 +79,19 @@ export function UsagePanelContent({ isBillingEnabled = false, onCreditChange, }: { - usage: CoPilotUsageStatus; + usage: CoPilotUsagePublic; showBillingLink?: boolean; hasInsufficientCredits?: boolean; isBillingEnabled?: boolean; onCreditChange?: () => void; }) { - const hasDailyLimit = usage.daily.limit > 0; - const hasWeeklyLimit = usage.weekly.limit > 0; - const isDailyExhausted = - hasDailyLimit && usage.daily.used >= usage.daily.limit; - const isWeeklyExhausted = - hasWeeklyLimit && usage.weekly.used >= usage.weekly.limit; + const daily = usage.daily; + const weekly = usage.weekly; + const isDailyExhausted = !!daily && daily.percent_used >= 100; + const isWeeklyExhausted = !!weekly && weekly.percent_used >= 100; const resetCost = usage.reset_cost ?? 0; - if (!hasDailyLimit && !hasWeeklyLimit) { + if (!daily && !weekly) { return (
No usage limits configured
); @@ -113,20 +111,18 @@ export function UsagePanelContent({ {tierLabel} plan )}
- {hasDailyLimit && ( + {daily && ( )} - {hasWeeklyLimit && ( + {weekly && ( )} {isDailyExhausted && diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/__tests__/UsageLimits.test.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/__tests__/UsageLimits.test.tsx index 9c7a78599f..67595dceec 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/__tests__/UsageLimits.test.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/__tests__/UsageLimits.test.tsx @@ -2,10 +2,19 @@ import { render, screen, cleanup } from "@/tests/integrations/test-utils"; import { afterEach, describe, expect, it, vi } from "vitest"; import { UsageLimits } from "../UsageLimits"; -// Mock the generated Orval hook +// Mock the generated Orval hook, exercising the `select` callback so its +// line counts as covered alongside the rest of the options. const mockUseGetV2GetCopilotUsage = vi.fn(); vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({ - useGetV2GetCopilotUsage: (opts: unknown) => mockUseGetV2GetCopilotUsage(opts), + useGetV2GetCopilotUsage: (opts: { + query?: { select?: (r: { data: unknown }) => unknown }; + }) => { + const ret = mockUseGetV2GetCopilotUsage(opts) as { data?: unknown }; + if (ret?.data !== undefined && typeof opts?.query?.select === "function") { + opts.query.select({ data: ret.data }); + } + return ret; + }, })); // Mock Popover to render children directly (Radix portals don't work in happy-dom) @@ -27,22 +36,24 @@ afterEach(() => { }); function makeUsage({ - dailyUsed = 500, - dailyLimit = 10000, - weeklyUsed = 2000, - weeklyLimit = 50000, + dailyPercent = 5, + weeklyPercent = 4, tier = "FREE", }: { - dailyUsed?: number; - dailyLimit?: number; - weeklyUsed?: number; - weeklyLimit?: number; + dailyPercent?: number | null; + weeklyPercent?: number | null; tier?: string; } = {}) { - const future = new Date(Date.now() + 3600 * 1000); // 1h from now + const future = new Date(Date.now() + 3600 * 1000).toISOString(); return { - daily: { used: dailyUsed, limit: dailyLimit, resets_at: future }, - weekly: { used: weeklyUsed, limit: weeklyLimit, resets_at: future }, + daily: + dailyPercent === null + ? null + : { percent_used: dailyPercent, resets_at: future }, + weekly: + weeklyPercent === null + ? null + : { percent_used: weeklyPercent, resets_at: future }, tier, }; } @@ -51,7 +62,7 @@ describe("UsageLimits", () => { it("renders nothing while loading", () => { mockUseGetV2GetCopilotUsage.mockReturnValue({ data: undefined, - isLoading: true, + isSuccess: false, }); const { container } = render(); expect(container.innerHTML).toBe(""); @@ -59,8 +70,8 @@ describe("UsageLimits", () => { it("renders nothing when no limits are configured", () => { mockUseGetV2GetCopilotUsage.mockReturnValue({ - data: makeUsage({ dailyLimit: 0, weeklyLimit: 0 }), - isLoading: false, + data: makeUsage({ dailyPercent: null, weeklyPercent: null }), + isSuccess: true, }); const { container } = render(); expect(container.innerHTML).toBe(""); @@ -69,16 +80,16 @@ describe("UsageLimits", () => { it("renders the usage button when limits exist", () => { mockUseGetV2GetCopilotUsage.mockReturnValue({ data: makeUsage(), - isLoading: false, + isSuccess: true, }); render(); expect(screen.getByRole("button", { name: /usage limits/i })).toBeDefined(); }); - it("displays daily and weekly usage percentages", () => { + it("displays daily and weekly percentage", () => { mockUseGetV2GetCopilotUsage.mockReturnValue({ - data: makeUsage({ dailyUsed: 5000, dailyLimit: 10000 }), - isLoading: false, + data: makeUsage({ dailyPercent: 50, weeklyPercent: 4 }), + isSuccess: true, }); render(); @@ -88,14 +99,10 @@ describe("UsageLimits", () => { expect(screen.getByText("Usage limits")).toBeDefined(); }); - it("shows only weekly bar when daily limit is 0", () => { + it("shows only weekly bar when daily is null", () => { mockUseGetV2GetCopilotUsage.mockReturnValue({ - data: makeUsage({ - dailyLimit: 0, - weeklyUsed: 25000, - weeklyLimit: 50000, - }), - isLoading: false, + data: makeUsage({ dailyPercent: null, weeklyPercent: 50 }), + isSuccess: true, }); render(); @@ -103,20 +110,22 @@ describe("UsageLimits", () => { expect(screen.queryByText("Today")).toBeNull(); }); - it("caps percentage at 100% when over limit", () => { + it("caps bar width at 100% when over limit", () => { + // 150% exercises the clamp — 100% exactly is merely exhausted, not over. mockUseGetV2GetCopilotUsage.mockReturnValue({ - data: makeUsage({ dailyUsed: 15000, dailyLimit: 10000 }), - isLoading: false, + data: makeUsage({ dailyPercent: 150 }), + isSuccess: true, }); render(); - expect(screen.getByText("100% used")).toBeDefined(); + const dailyBar = screen.getByRole("progressbar", { name: /today usage/i }); + expect(dailyBar.getAttribute("aria-valuenow")).toBe("100"); }); it("displays the user tier label", () => { mockUseGetV2GetCopilotUsage.mockReturnValue({ data: makeUsage({ tier: "PRO" }), - isLoading: false, + isSuccess: true, }); render(); @@ -126,7 +135,7 @@ describe("UsageLimits", () => { it("shows learn more link to credits page", () => { mockUseGetV2GetCopilotUsage.mockReturnValue({ data: makeUsage(), - isLoading: false, + isSuccess: true, }); render(); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/__tests__/UsagePanelContentRender.test.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/__tests__/UsagePanelContentRender.test.tsx index 9230663381..db2d4241a8 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/__tests__/UsagePanelContentRender.test.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/UsageLimits/__tests__/UsagePanelContentRender.test.tsx @@ -6,7 +6,7 @@ import { } from "@/tests/integrations/test-utils"; import { afterEach, describe, expect, it, vi } from "vitest"; import { UsagePanelContent } from "../UsagePanelContent"; -import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus"; +import type { CoPilotUsagePublic } from "@/app/api/__generated__/models/coPilotUsagePublic"; const mockResetUsage = vi.fn(); vi.mock("../../../hooks/useResetRateLimit", () => ({ @@ -20,36 +20,38 @@ afterEach(() => { function makeUsage( overrides: Partial<{ - dailyUsed: number; - dailyLimit: number; - weeklyUsed: number; - weeklyLimit: number; + dailyPercent: number | null; + weeklyPercent: number | null; tier: string; resetCost: number; }> = {}, -): CoPilotUsageStatus { +): CoPilotUsagePublic { const { - dailyUsed = 500, - dailyLimit = 10000, - weeklyUsed = 2000, - weeklyLimit = 50000, + dailyPercent = 5, + weeklyPercent = 4, tier = "FREE", resetCost = 100, } = overrides; - const future = new Date(Date.now() + 3600 * 1000); + const future = new Date(Date.now() + 3600 * 1000).toISOString(); return { - daily: { used: dailyUsed, limit: dailyLimit, resets_at: future }, - weekly: { used: weeklyUsed, limit: weeklyLimit, resets_at: future }, + daily: + dailyPercent === null + ? null + : { percent_used: dailyPercent, resets_at: future }, + weekly: + weeklyPercent === null + ? null + : { percent_used: weeklyPercent, resets_at: future }, tier, reset_cost: resetCost, - } as CoPilotUsageStatus; + } as CoPilotUsagePublic; } describe("UsagePanelContent", () => { - it("renders 'No usage limits configured' when both limits are zero", () => { + it("renders 'No usage limits configured' when both windows are null", () => { render( , ); expect(screen.getByText("No usage limits configured")).toBeDefined(); @@ -58,11 +60,7 @@ describe("UsagePanelContent", () => { it("renders the reset button when daily limit is exhausted", () => { render( , ); expect(screen.getByText(/Reset daily limit/)).toBeDefined(); @@ -72,10 +70,8 @@ describe("UsagePanelContent", () => { render( , @@ -86,11 +82,7 @@ describe("UsagePanelContent", () => { it("calls resetUsage when the reset button is clicked", () => { render( , ); fireEvent.click(screen.getByText(/Reset daily limit/)); @@ -100,15 +92,21 @@ describe("UsagePanelContent", () => { it("renders 'Add credits' link when insufficient credits", () => { render( , ); expect(screen.getByText("Add credits to reset")).toBeDefined(); }); + + it("renders percent used in the usage bar", () => { + render(); + expect(screen.getByText("25% used")).toBeDefined(); + }); + + it("renders '<1% used' when usage is greater than 0 but rounds to 0", () => { + render(); + expect(screen.getByText("<1% used")).toBeDefined(); + }); }); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/__tests__/usageHelpers.test.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/components/__tests__/usageHelpers.test.ts new file mode 100644 index 0000000000..eecdb70245 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/__tests__/usageHelpers.test.ts @@ -0,0 +1,76 @@ +import { describe, expect, it } from "vitest"; +import { + formatCents, + formatMicrodollarsAsUsd, + formatResetTime, +} from "../usageHelpers"; + +describe("formatCents", () => { + it("formats whole dollars", () => { + expect(formatCents(500)).toBe("$5.00"); + }); + + it("formats zero", () => { + expect(formatCents(0)).toBe("$0.00"); + }); + + it("formats fractional cents", () => { + expect(formatCents(1999)).toBe("$19.99"); + }); +}); + +describe("formatMicrodollarsAsUsd", () => { + it("formats zero as $0.00", () => { + expect(formatMicrodollarsAsUsd(0)).toBe("$0.00"); + }); + + it("formats whole dollar amounts", () => { + expect(formatMicrodollarsAsUsd(1_500_000)).toBe("$1.50"); + }); + + it("formats amounts that round to $0.00 but are > 0 as <$0.01", () => { + expect(formatMicrodollarsAsUsd(999)).toBe("<$0.01"); + }); + + it("formats exactly one cent as $0.01", () => { + expect(formatMicrodollarsAsUsd(10_000)).toBe("$0.01"); + }); + + it("formats negative input with toFixed semantics (no special case)", () => { + // Negative should never come from the backend, but the helper is + // safe — it simply passes through `toFixed`. + expect(formatMicrodollarsAsUsd(-1_500_000)).toBe("$-1.50"); + }); + + it("formats very large values without truncating", () => { + expect(formatMicrodollarsAsUsd(1_234_567_890)).toBe("$1234.57"); + }); +}); + +describe("formatResetTime", () => { + it("returns 'now' when reset time is in the past", () => { + const now = new Date("2026-04-21T12:00:00Z"); + const past = new Date("2026-04-21T11:59:00Z"); + expect(formatResetTime(past, now)).toBe("now"); + }); + + it("renders sub-hour resets as minutes", () => { + const now = new Date("2026-04-21T12:00:00Z"); + const future = new Date("2026-04-21T12:15:00Z"); + expect(formatResetTime(future, now)).toBe("in 15m"); + }); + + it("renders same-day resets as 'Xh Ym'", () => { + const now = new Date("2026-04-21T12:00:00Z"); + const future = new Date("2026-04-21T14:30:00Z"); + expect(formatResetTime(future, now)).toBe("in 2h 30m"); + }); + + it("renders future-day resets as a localized date string", () => { + const now = new Date("2026-04-21T12:00:00Z"); + const future = new Date("2026-04-24T12:00:00Z"); + // Not asserting exact format (localized), just that it's not the + // minute/hour form. + expect(formatResetTime(future, now)).not.toMatch(/^in \d/); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/usageHelpers.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/components/usageHelpers.ts index 599442075f..f25df85e9b 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/usageHelpers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/usageHelpers.ts @@ -2,6 +2,12 @@ export function formatCents(cents: number): string { return `$${(cents / 100).toFixed(2)}`; } +export function formatMicrodollarsAsUsd(microdollars: number): string { + const dollars = microdollars / 1_000_000; + if (microdollars > 0 && dollars < 0.01) return "<$0.01"; + return `$${dollars.toFixed(2)}`; +} + export function formatResetTime( resetsAt: Date | string, now: Date = new Date(), diff --git a/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/BriefingTabContent.tsx b/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/BriefingTabContent.tsx index 939ec5403f..fc6e26424d 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/BriefingTabContent.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/BriefingTabContent.tsx @@ -1,6 +1,6 @@ "use client"; -import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus"; +import type { CoPilotUsagePublic } from "@/app/api/__generated__/models/coPilotUsagePublic"; import type { LibraryAgent } from "@/app/api/__generated__/models/libraryAgent"; import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat"; import { @@ -42,9 +42,9 @@ export function BriefingTabContent({ activeTab, agents }: Props) { } function UsageSection() { - const { data: usage } = useGetV2GetCopilotUsage({ + const { data: usage, isSuccess } = useGetV2GetCopilotUsage({ query: { - select: (res) => res.data as CoPilotUsageStatus, + select: (res) => res.data as CoPilotUsagePublic, refetchInterval: 30000, staleTime: 10000, }, @@ -56,7 +56,8 @@ function UsageSection() { const hasInsufficientCredits = credits !== null && resetCost != null && credits < resetCost; - if (!usage?.daily || !usage?.weekly) return null; + if (!isSuccess || !usage) return null; + if (!usage.daily && !usage.weekly) return null; return (
@@ -80,19 +81,17 @@ function UsageSection() { )}
- {usage.daily.limit > 0 && ( + {usage.daily && ( )} - {usage.weekly.limit > 0 && ( + {usage.weekly && ( )} @@ -244,14 +243,12 @@ function UsageFooter({ hasInsufficientCredits, onCreditChange, }: { - usage: CoPilotUsageStatus; + usage: CoPilotUsagePublic; hasInsufficientCredits: boolean; onCreditChange?: () => void; }) { - const isDailyExhausted = - usage.daily.limit > 0 && usage.daily.used >= usage.daily.limit; - const isWeeklyExhausted = - usage.weekly.limit > 0 && usage.weekly.used >= usage.weekly.limit; + const isDailyExhausted = !!usage.daily && usage.daily.percent_used >= 100; + const isWeeklyExhausted = !!usage.weekly && usage.weekly.percent_used >= 100; const resetCost = usage.reset_cost ?? 0; const { resetUsage, isPending } = useResetRateLimit({ onCreditChange }); @@ -294,22 +291,17 @@ function UsageFooter({ function UsageMeter({ label, - used, - limit, + percentUsed, resetsAt, }: { label: string; - used: number; - limit: number; + percentUsed: number; resetsAt: Date | string; }) { - if (limit <= 0) return null; - - const rawPercent = (used / limit) * 100; - const percent = Math.min(100, Math.round(rawPercent)); + const percent = Math.min(100, Math.max(0, Math.round(percentUsed))); const isHigh = percent >= 80; const percentLabel = - used > 0 && percent === 0 ? "<1% used" : `${percent}% used`; + percentUsed > 0 && percent === 0 ? "<1% used" : `${percent}% used`; return (
@@ -323,20 +315,20 @@ function UsageMeter({
0 ? 1 : 0, percent)}%` }} + style={{ width: `${Math.max(percent > 0 ? 1 : 0, percent)}%` }} />
-
- - {used.toLocaleString()} / {limit.toLocaleString()} - - - Resets {formatResetTime(resetsAt)} - -
+ + Resets {formatResetTime(resetsAt)} +
); } diff --git a/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/__tests__/BriefingTabContent.test.tsx b/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/__tests__/BriefingTabContent.test.tsx new file mode 100644 index 0000000000..5dbb3bab17 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/library/components/AgentBriefingPanel/__tests__/BriefingTabContent.test.tsx @@ -0,0 +1,212 @@ +import { render, screen, cleanup } from "@/tests/integrations/test-utils"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import { BriefingTabContent } from "../BriefingTabContent"; + +const mockUseGetV2GetCopilotUsage = vi.fn(); +vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({ + useGetV2GetCopilotUsage: (opts: { + query?: { select?: (r: { data: unknown }) => unknown }; + }) => { + const ret = mockUseGetV2GetCopilotUsage(opts) as { data?: unknown }; + // Exercise the `select` callback so its line counts as covered. + if (ret?.data !== undefined && typeof opts?.query?.select === "function") { + opts.query.select({ data: ret.data }); + } + return ret; + }, +})); + +const mockUseGetFlag = vi.fn(); +vi.mock("@/services/feature-flags/use-get-flag", async () => { + const actual = await vi.importActual< + typeof import("@/services/feature-flags/use-get-flag") + >("@/services/feature-flags/use-get-flag"); + return { + ...actual, + useGetFlag: (flag: unknown) => mockUseGetFlag(flag), + }; +}); + +const mockUseCredits = vi.fn(); +vi.mock("@/hooks/useCredits", () => ({ + default: (opts: unknown) => mockUseCredits(opts), +})); + +const mockResetUsage = vi.fn(); +vi.mock("@/app/(platform)/copilot/hooks/useResetRateLimit", () => ({ + useResetRateLimit: () => ({ + resetUsage: mockResetUsage, + isPending: false, + }), +})); + +afterEach(() => { + cleanup(); + mockUseGetV2GetCopilotUsage.mockReset(); + mockUseGetFlag.mockReset(); + mockUseCredits.mockReset(); + mockResetUsage.mockReset(); +}); + +function makeUsage({ + dailyPercent = 5, + weeklyPercent = 4, + tier = "FREE", + resetCost = 500, +}: { + dailyPercent?: number | null; + weeklyPercent?: number | null; + tier?: string; + resetCost?: number; +} = {}) { + const future = new Date(Date.now() + 3600 * 1000).toISOString(); + return { + daily: + dailyPercent === null + ? null + : { percent_used: dailyPercent, resets_at: future }, + weekly: + weeklyPercent === null + ? null + : { percent_used: weeklyPercent, resets_at: future }, + tier, + reset_cost: resetCost, + }; +} + +describe("BriefingTabContent — UsageSection", () => { + it("renders nothing when usage fetch has not succeeded", () => { + mockUseGetV2GetCopilotUsage.mockReturnValue({ + data: undefined, + isSuccess: false, + }); + mockUseGetFlag.mockReturnValue(false); + mockUseCredits.mockReturnValue({ credits: 1000, fetchCredits: vi.fn() }); + const { container } = render( + , + ); + expect(container.innerHTML).toBe(""); + }); + + it("renders nothing when both windows are null (no limits configured)", () => { + mockUseGetV2GetCopilotUsage.mockReturnValue({ + data: makeUsage({ dailyPercent: null, weeklyPercent: null }), + isSuccess: true, + }); + mockUseGetFlag.mockReturnValue(false); + mockUseCredits.mockReturnValue({ credits: 1000, fetchCredits: vi.fn() }); + const { container } = render( + , + ); + expect(container.innerHTML).toBe(""); + }); + + it("renders tier badge + daily+weekly meters at normal usage", () => { + mockUseGetV2GetCopilotUsage.mockReturnValue({ + data: makeUsage({ dailyPercent: 12, weeklyPercent: 4, tier: "PRO" }), + isSuccess: true, + }); + mockUseGetFlag.mockReturnValue(true); + mockUseCredits.mockReturnValue({ credits: 1000, fetchCredits: vi.fn() }); + render(); + + expect(screen.getByText("Usage limits")).toBeDefined(); + expect(screen.getByText("Pro plan")).toBeDefined(); + expect(screen.getByText("12% used")).toBeDefined(); + expect(screen.getByText("4% used")).toBeDefined(); + expect(screen.getByText("Today")).toBeDefined(); + expect(screen.getByText("This week")).toBeDefined(); + expect(screen.getByText("Manage billing")).toBeDefined(); + }); + + it("shows reset button when daily limit is exhausted and user has credits", () => { + mockUseGetV2GetCopilotUsage.mockReturnValue({ + data: makeUsage({ dailyPercent: 100, weeklyPercent: 40, resetCost: 500 }), + isSuccess: true, + }); + mockUseGetFlag.mockReturnValue(true); + mockUseCredits.mockReturnValue({ credits: 1000, fetchCredits: vi.fn() }); + render(); + + expect(screen.getByText(/Reset daily limit/)).toBeDefined(); + }); + + it("shows 'Add credits' CTA when daily exhausted but user lacks credits", () => { + mockUseGetV2GetCopilotUsage.mockReturnValue({ + data: makeUsage({ dailyPercent: 100, weeklyPercent: 40, resetCost: 500 }), + isSuccess: true, + }); + mockUseGetFlag.mockReturnValue(true); + mockUseCredits.mockReturnValue({ credits: 10, fetchCredits: vi.fn() }); + render(); + + expect(screen.getByText("Add credits to reset")).toBeDefined(); + expect(screen.queryByText(/Reset daily limit/)).toBeNull(); + }); + + it("hides reset CTAs when the weekly limit is also exhausted", () => { + mockUseGetV2GetCopilotUsage.mockReturnValue({ + data: makeUsage({ + dailyPercent: 100, + weeklyPercent: 100, + resetCost: 500, + }), + isSuccess: true, + }); + mockUseGetFlag.mockReturnValue(true); + mockUseCredits.mockReturnValue({ credits: 1000, fetchCredits: vi.fn() }); + render(); + + expect(screen.queryByText(/Reset daily limit/)).toBeNull(); + expect(screen.queryByText("Add credits to reset")).toBeNull(); + }); + + it("renders <1% used when percent is >0 but rounds to 0", () => { + mockUseGetV2GetCopilotUsage.mockReturnValue({ + data: makeUsage({ dailyPercent: 0.4, weeklyPercent: 0 }), + isSuccess: true, + }); + mockUseGetFlag.mockReturnValue(false); + mockUseCredits.mockReturnValue({ credits: 1000, fetchCredits: vi.fn() }); + render(); + + expect(screen.getByText("<1% used")).toBeDefined(); + }); + + it("dispatches to ExecutionListSection for running/attention/completed tabs", () => { + mockUseGetV2GetCopilotUsage.mockReturnValue({ + data: undefined, + isSuccess: false, + }); + mockUseGetFlag.mockReturnValue(false); + mockUseCredits.mockReturnValue({ credits: 1000, fetchCredits: vi.fn() }); + + for (const tab of ["running", "attention", "completed"] as const) { + const { unmount } = render( + , + ); + // Empty list -> EmptyMessage renders for each of the execution tabs. + expect( + screen.getByText(/No agents|No recently completed/i), + ).toBeDefined(); + unmount(); + } + }); + + it("dispatches to AgentListSection for listening/scheduled/idle tabs", () => { + mockUseGetV2GetCopilotUsage.mockReturnValue({ + data: undefined, + isSuccess: false, + }); + mockUseGetFlag.mockReturnValue(false); + mockUseCredits.mockReturnValue({ credits: 1000, fetchCredits: vi.fn() }); + + for (const tab of ["listening", "scheduled", "idle"] as const) { + const { unmount } = render( + , + ); + expect(screen.getByText(/No/i)).toBeDefined(); + unmount(); + } + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/page.tsx b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/page.tsx index fb565c048b..f6f9398721 100644 --- a/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/page.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/profile/(user)/credits/page.tsx @@ -13,7 +13,7 @@ import { RefundModal } from "./RefundModal"; import { SubscriptionTierSection } from "./components/SubscriptionTierSection/SubscriptionTierSection"; import { CreditTransaction } from "@/lib/autogpt-server-api"; import { UsagePanelContent } from "@/app/(platform)/copilot/components/UsageLimits/UsageLimits"; -import type { CoPilotUsageStatus } from "@/app/api/__generated__/models/coPilotUsageStatus"; +import type { CoPilotUsagePublic } from "@/app/api/__generated__/models/coPilotUsagePublic"; import { useGetV2GetCopilotUsage } from "@/app/api/__generated__/endpoints/chat/chat"; import { @@ -27,16 +27,16 @@ import { function CoPilotUsageSection() { const router = useRouter(); - const { data: usage, isLoading } = useGetV2GetCopilotUsage({ + const { data: usage, isSuccess } = useGetV2GetCopilotUsage({ query: { - select: (res) => res.data as CoPilotUsageStatus, + select: (res) => res.data as CoPilotUsagePublic, refetchInterval: 30000, staleTime: 10000, }, }); - if (isLoading || !usage?.daily || !usage?.weekly) return null; - if (usage.daily.limit <= 0 && usage.weekly.limit <= 0) return null; + if (!isSuccess || !usage) return null; + if (!usage.daily && !usage.weekly) return null; return (
diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index f20f34a805..9103d6f475 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -1793,7 +1793,7 @@ } }, "429": { - "description": "Token rate-limit or call-frequency cap exceeded" + "description": "Cost rate-limit or call-frequency cap exceeded" } } } @@ -1879,14 +1879,14 @@ "get": { "tags": ["v2", "chat", "chat"], "summary": "Get Copilot Usage", - "description": "Get CoPilot usage status for the authenticated user.\n\nReturns current token usage vs limits for daily and weekly windows.\nGlobal defaults sourced from LaunchDarkly (falling back to config).\nIncludes the user's rate-limit tier.", + "description": "Get CoPilot usage status for the authenticated user.\n\nReturns the percentage of the daily/weekly allowance used — not the\nraw spend or cap — so clients cannot derive per-turn cost or platform\nmargins. Global defaults sourced from LaunchDarkly (falling back to\nconfig). Includes the user's rate-limit tier.", "operationId": "getV2GetCopilotUsage", "responses": { "200": { "description": "Successful Response", "content": { "application/json": { - "schema": { "$ref": "#/components/schemas/CoPilotUsageStatus" } + "schema": { "$ref": "#/components/schemas/CoPilotUsagePublic" } } } }, @@ -1901,7 +1901,7 @@ "post": { "tags": ["v2", "chat", "chat"], "summary": "Reset Copilot Usage", - "description": "Reset the daily CoPilot rate limit by spending credits.\n\nAllows users who have hit their daily token limit to spend credits\nto reset their daily usage counter and continue working.\nReturns 400 if the feature is disabled or the user is not over the limit.\nReturns 402 if the user has insufficient credits.", + "description": "Reset the daily CoPilot rate limit by spending credits.\n\nAllows users who have hit their daily cost limit to spend credits\nto reset their daily usage counter and continue working.\nReturns 400 if the feature is disabled or the user is not over the limit.\nReturns 402 if the user has insufficient credits.", "operationId": "postV2ResetCopilotUsage", "responses": { "200": { @@ -9211,10 +9211,22 @@ "title": "ClarifyingQuestion", "description": "A question that needs user clarification." }, - "CoPilotUsageStatus": { + "CoPilotUsagePublic": { "properties": { - "daily": { "$ref": "#/components/schemas/UsageWindow" }, - "weekly": { "$ref": "#/components/schemas/UsageWindow" }, + "daily": { + "anyOf": [ + { "$ref": "#/components/schemas/UsageWindowPublic" }, + { "type": "null" } + ], + "description": "Null when no daily cap is configured (unlimited)." + }, + "weekly": { + "anyOf": [ + { "$ref": "#/components/schemas/UsageWindowPublic" }, + { "type": "null" } + ], + "description": "Null when no weekly cap is configured (unlimited)." + }, "tier": { "$ref": "#/components/schemas/SubscriptionTier", "default": "FREE" @@ -9227,9 +9239,8 @@ } }, "type": "object", - "required": ["daily", "weekly"], - "title": "CoPilotUsageStatus", - "description": "Current usage status for a user across all windows." + "title": "CoPilotUsagePublic", + "description": "Current usage status for a user — public (client-safe) shape." }, "ContentType": { "type": "string", @@ -12997,8 +13008,8 @@ "description": "Credit balance after charge (in cents)" }, "usage": { - "$ref": "#/components/schemas/CoPilotUsageStatus", - "description": "Updated usage status after reset" + "$ref": "#/components/schemas/CoPilotUsagePublic", + "description": "Updated usage status after reset (percentages only)" } }, "type": "object", @@ -14259,7 +14270,7 @@ "type": "string", "enum": ["FREE", "PRO", "BUSINESS", "ENTERPRISE"], "title": "SubscriptionTier", - "description": "Subscription tiers with increasing token allowances.\n\nMirrors the ``SubscriptionTier`` enum in ``schema.prisma``.\nOnce ``prisma generate`` is run, this can be replaced with::\n\n from prisma.enums import SubscriptionTier" + "description": "Subscription tiers with increasing cost allowances.\n\nMirrors the ``SubscriptionTier`` enum in ``schema.prisma``.\nOnce ``prisma generate`` is run, this can be replaced with::\n\n from prisma.enums import SubscriptionTier" }, "SubscriptionTierRequest": { "properties": { @@ -15886,13 +15897,14 @@ "required": ["timezone"], "title": "UpdateTimezoneRequest" }, - "UsageWindow": { + "UsageWindowPublic": { "properties": { - "used": { "type": "integer", "title": "Used" }, - "limit": { - "type": "integer", - "title": "Limit", - "description": "Maximum tokens allowed in this window. 0 means unlimited." + "percent_used": { + "type": "number", + "maximum": 100.0, + "minimum": 0.0, + "title": "Percent Used", + "description": "Percentage of the window's allowance used (0-100). Clamped at 100 when over the cap." }, "resets_at": { "type": "string", @@ -15901,9 +15913,9 @@ } }, "type": "object", - "required": ["used", "limit", "resets_at"], - "title": "UsageWindow", - "description": "Usage within a single time window." + "required": ["percent_used", "resets_at"], + "title": "UsageWindowPublic", + "description": "Public view of a usage window — only the percentage and reset time.\n\nHides the raw spend and the cap so clients cannot derive per-turn cost\nor reverse-engineer platform margins. ``percent_used`` is capped at 100." }, "UserCostSummary": { "properties": { @@ -16144,31 +16156,31 @@ "anyOf": [{ "type": "string" }, { "type": "null" }], "title": "User Email" }, - "daily_token_limit": { + "daily_cost_limit_microdollars": { "type": "integer", - "title": "Daily Token Limit" + "title": "Daily Cost Limit Microdollars" }, - "weekly_token_limit": { + "weekly_cost_limit_microdollars": { "type": "integer", - "title": "Weekly Token Limit" + "title": "Weekly Cost Limit Microdollars" }, - "daily_tokens_used": { + "daily_cost_used_microdollars": { "type": "integer", - "title": "Daily Tokens Used" + "title": "Daily Cost Used Microdollars" }, - "weekly_tokens_used": { + "weekly_cost_used_microdollars": { "type": "integer", - "title": "Weekly Tokens Used" + "title": "Weekly Cost Used Microdollars" }, "tier": { "$ref": "#/components/schemas/SubscriptionTier" } }, "type": "object", "required": [ "user_id", - "daily_token_limit", - "weekly_token_limit", - "daily_tokens_used", - "weekly_tokens_used", + "daily_cost_limit_microdollars", + "weekly_cost_limit_microdollars", + "daily_cost_used_microdollars", + "weekly_cost_used_microdollars", "tier" ], "title": "UserRateLimitResponse" From f238c153a5bb445a99d1cd71228783584db08e39 Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Tue, 21 Apr 2026 16:27:01 +0700 Subject: [PATCH 05/41] fix(backend/copilot): release session cluster lock on completion (#12867) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Fixes a bug where a chat session gets silently stuck after the user presses Stop mid-turn. **Root cause:** the cancel endpoint marks the session `failed` after polling 5s, but the cluster lock held by the still-running task is only released by `on_run_done` when the task actually finishes. If the task hangs past the 5s poll (slow LLM call, agent-browser step, etc.), the lock lingers for up to 5 min — `stream_chat_post`'s `is_turn_in_flight` check sees the flipped meta (`failed`) and enqueues a new turn, but the run handler sees the stale lock and drops the user's message at `manager.py:379` (`reject+requeue=False`). The new SSE stream hangs until its 60s idle timeout. ### Fix Two cooperating changes: 1. **`mark_session_completed` force-releases the cluster lock** in the same transaction that flips status to `completed`/`failed`. Unconditional delete — by the time we're declaring the session dead, we don't care who the current lock holder is; the lock has to go so the next enqueued turn can acquire. This is what closes the stuck-session window. 2. **`ClusterLock.release()` is now owner-checked** (Lua CAS — `GET == token ? DEL : noop` atomically). Force-release means another pod may legitimately own the key by the time the original task's `on_run_done` eventually fires. Without the CAS, that late `release()` would wipe the successor's lock. With it, the late `release()` is a safe no-op when the owner has changed. Together: prompt release on completion (via force-delete) + safe cleanup when on_run_done catches up (via CAS). That re-syncs the API-level `is_turn_in_flight` check with the actual lock state, so the contention window disappears. No changes to the worker-level contention handler: `stream_chat_post` already queues incoming messages into the pending buffer when a turn is in flight (via `queue_pending_for_http`). With these fixes, the worker never sees contention in the common case; if it does (true multi-pod race), the pre-existing `reject+requeue=False` behaviour still applies — we'll revisit that path with its own PR if it becomes a production symptom. ### Verification - Reproduced the original stuck-session symptom locally (Stop mid-turn → send new message → backend logs `Session … already running on pod …`, user message silently lost, SSE stream idle 60s then closes). - After the fix: cancel → new message → turn starts normally (lock released by `mark_session_completed`). - `poetry run pyright` — 0 errors on edited files. - `pytest backend/copilot/stream_registry_test.py backend/executor/cluster_lock_test.py` — 33 passed (includes the successor-not-wiped test). ## Changes - `autogpt_platform/backend/backend/copilot/executor/utils.py` — extract `get_session_lock_key(session_id)` helper so the lock-key format has a single source of truth. - `autogpt_platform/backend/backend/copilot/executor/manager.py` — use the helper where the cluster lock is created. - `autogpt_platform/backend/backend/copilot/stream_registry.py` — `mark_session_completed` deletes the lock key after the atomic status swap (force-release). - `autogpt_platform/backend/backend/executor/cluster_lock.py` — `ClusterLock.release()` (sync + async) uses a Lua CAS to only delete when `GET == token`, protecting against wiping a successor after a force-release. ## Test plan - [ ] Send a message in /copilot that triggers a long turn (e.g. `run_agent`), press Stop before it finishes, then send another message. Expect: new turn starts promptly (no 5-min wait for lock TTL). - [ ] Happy path regression — send a normal message, verify turn completes and the session lock key is deleted after completion. - [ ] Successor protection — unit test `test_release_does_not_wipe_successor_lock` covers: A acquires, external DEL, B acquires, A.release() is a no-op, B's lock intact. --- .../backend/copilot/executor/manager.py | 3 +- .../backend/backend/copilot/executor/utils.py | 6 + .../backend/copilot/stream_registry.py | 11 +- .../backend/copilot/stream_registry_test.py | 114 ++++++++++++++++++ .../backend/backend/executor/cluster_lock.py | 31 ++++- .../backend/executor/cluster_lock_test.py | 27 +++++ 6 files changed, 185 insertions(+), 7 deletions(-) diff --git a/autogpt_platform/backend/backend/copilot/executor/manager.py b/autogpt_platform/backend/backend/copilot/executor/manager.py index da113ccc50..02a2913883 100644 --- a/autogpt_platform/backend/backend/copilot/executor/manager.py +++ b/autogpt_platform/backend/backend/copilot/executor/manager.py @@ -34,6 +34,7 @@ from .utils import ( CancelCoPilotEvent, CoPilotExecutionEntry, create_copilot_queue_config, + get_session_lock_key, ) logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]") @@ -366,7 +367,7 @@ class CoPilotExecutor(AppProcess): # Try to acquire cluster-wide lock cluster_lock = ClusterLock( redis=redis.get_redis(), - key=f"copilot:session:{session_id}:lock", + key=get_session_lock_key(session_id), owner_id=self.executor_id, timeout=settings.config.cluster_lock_timeout, ) diff --git a/autogpt_platform/backend/backend/copilot/executor/utils.py b/autogpt_platform/backend/backend/copilot/executor/utils.py index b96e1821a1..a2b051d82b 100644 --- a/autogpt_platform/backend/backend/copilot/executor/utils.py +++ b/autogpt_platform/backend/backend/copilot/executor/utils.py @@ -82,6 +82,12 @@ COPILOT_CANCEL_EXCHANGE = Exchange( ) COPILOT_CANCEL_QUEUE_NAME = "copilot_cancel_queue" + +def get_session_lock_key(session_id: str) -> str: + """Redis key for the per-session cluster lock held by the executing pod.""" + return f"copilot:session:{session_id}:lock" + + # CoPilot operations can include extended thinking and agent generation # which may take 30+ minutes to complete COPILOT_CONSUMER_TIMEOUT_SECONDS = 60 * 60 # 1 hour diff --git a/autogpt_platform/backend/backend/copilot/stream_registry.py b/autogpt_platform/backend/backend/copilot/stream_registry.py index f4a26b7008..424964e075 100644 --- a/autogpt_platform/backend/backend/copilot/stream_registry.py +++ b/autogpt_platform/backend/backend/copilot/stream_registry.py @@ -35,7 +35,7 @@ from backend.data.redis_client import get_redis_async from backend.data.redis_helpers import hash_compare_and_set from .config import ChatConfig -from .executor.utils import COPILOT_CONSUMER_TIMEOUT_SECONDS +from .executor.utils import COPILOT_CONSUMER_TIMEOUT_SECONDS, get_session_lock_key from .response_model import ( ResponseType, StreamBaseResponse, @@ -851,6 +851,15 @@ async def mark_session_completed( logger.debug(f"Session {session_id} already completed/failed, skipping") return False + # Force-release the executor's cluster lock so the next enqueued turn can + # acquire it immediately. The lock holder's on_run_done will also release + # (idempotent delete); doing it here unblocks cases where the task hangs + # past the cancel timeout or a pod crash leaves the lock orphaned. + try: + await redis.delete(get_session_lock_key(session_id)) + except RedisError as e: + logger.warning(f"Failed to release cluster lock for session {session_id}: {e}") + if error_message and not skip_error_publish: try: await publish_chunk(turn_id, StreamError(errorText=error_message)) diff --git a/autogpt_platform/backend/backend/copilot/stream_registry_test.py b/autogpt_platform/backend/backend/copilot/stream_registry_test.py index 28ec199025..db26a5f524 100644 --- a/autogpt_platform/backend/backend/copilot/stream_registry_test.py +++ b/autogpt_platform/backend/backend/copilot/stream_registry_test.py @@ -4,8 +4,10 @@ import asyncio from unittest.mock import AsyncMock, patch import pytest +from redis.exceptions import RedisError from backend.copilot import stream_registry +from backend.copilot.executor.utils import get_session_lock_key @pytest.fixture(autouse=True) @@ -221,3 +223,115 @@ async def test_stream_and_publish_consumer_break_then_aclose_releases_inner(): await wrapper.aclose() assert inner_finally_ran.is_set() + + +# --------------------------------------------------------------------------- +# mark_session_completed: the atomic meta flip to completed/failed must also +# release the per-session cluster lock, so the next enqueued turn's run +# handler can acquire it without waiting for the TTL (5 min default). +# --------------------------------------------------------------------------- + + +class _FakeRedis: + """Minimal async-Redis fake: only the calls mark_session_completed makes.""" + + def __init__(self, meta: dict[str, str]): + self._meta = dict(meta) + self.deleted_keys: list[str] = [] + self.delete = AsyncMock(side_effect=self._record_delete) + + async def _record_delete(self, *keys: str): + self.deleted_keys.extend(keys) + for k in keys: + self._meta.pop(k, None) + return len(keys) + + async def hgetall(self, _key: str): + return dict(self._meta) + + +@pytest.mark.asyncio +async def test_mark_session_completed_releases_cluster_lock_on_success(): + """CAS swap must be followed by a DELETE on the session's lock key so a + stuck-because-of-stale-lock session becomes immediately claimable.""" + fake_redis = _FakeRedis({"status": "running", "turn_id": "turn-1"}) + + with ( + patch.object( + stream_registry, "get_redis_async", new=AsyncMock(return_value=fake_redis) + ), + patch.object( + stream_registry, "hash_compare_and_set", new=AsyncMock(return_value=True) + ), + patch.object(stream_registry, "publish_chunk", new=AsyncMock()), + patch.object( + stream_registry.chat_db(), + "set_turn_duration", + new=AsyncMock(), + create=True, + ), + ): + result = await stream_registry.mark_session_completed("sess-1") + + assert result is True + assert get_session_lock_key("sess-1") in fake_redis.deleted_keys + + +@pytest.mark.asyncio +async def test_mark_session_completed_skips_lock_release_when_already_completed(): + """CAS failure = someone else completed the session first; we must not + delete their already-released lock, and we must NOT publish StreamFinish + twice (the winning caller already published it).""" + fake_redis = _FakeRedis({"status": "completed", "turn_id": "turn-1"}) + publish_mock = AsyncMock() + + with ( + patch.object( + stream_registry, "get_redis_async", new=AsyncMock(return_value=fake_redis) + ), + patch.object( + stream_registry, "hash_compare_and_set", new=AsyncMock(return_value=False) + ), + patch.object(stream_registry, "publish_chunk", new=publish_mock), + ): + result = await stream_registry.mark_session_completed("sess-1") + + assert result is False + assert get_session_lock_key("sess-1") not in fake_redis.deleted_keys + assert not any( + isinstance(call.args[1], stream_registry.StreamFinish) + for call in publish_mock.call_args_list + ), "StreamFinish must NOT be re-published on the CAS-no-op branch" + + +@pytest.mark.asyncio +async def test_mark_session_completed_survives_lock_release_redis_error(): + """A Redis hiccup during lock DELETE must not prevent the StreamFinish + publish — the client's SSE stream would otherwise hang on the stale meta + status while Redis recovers.""" + fake_redis = _FakeRedis({"status": "running", "turn_id": "turn-1"}) + fake_redis.delete = AsyncMock(side_effect=RedisError("boom")) + publish_mock = AsyncMock() + + with ( + patch.object( + stream_registry, "get_redis_async", new=AsyncMock(return_value=fake_redis) + ), + patch.object( + stream_registry, "hash_compare_and_set", new=AsyncMock(return_value=True) + ), + patch.object(stream_registry, "publish_chunk", new=publish_mock), + patch.object( + stream_registry.chat_db(), + "set_turn_duration", + new=AsyncMock(), + create=True, + ), + ): + result = await stream_registry.mark_session_completed("sess-1") + + assert result is True + assert any( + isinstance(call.args[1], stream_registry.StreamFinish) + for call in publish_mock.call_args_list + ), "StreamFinish must still be published even if lock DELETE raises" diff --git a/autogpt_platform/backend/backend/executor/cluster_lock.py b/autogpt_platform/backend/backend/executor/cluster_lock.py index 0732c3f6de..9fe8b744c4 100644 --- a/autogpt_platform/backend/backend/executor/cluster_lock.py +++ b/autogpt_platform/backend/backend/executor/cluster_lock.py @@ -4,7 +4,7 @@ import asyncio import logging import threading import time -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, cast if TYPE_CHECKING: from redis import Redis @@ -12,6 +12,17 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# Lua CAS release: only delete the key if the stored value still matches our +# owner_id. Returns 1 on delete, 0 on no-op. This makes release() safe against +# the race where an external caller (e.g. mark_session_completed's force-release) +# deletes our key and a new owner acquires it before our release() fires — without +# the CAS guard, release() would wipe the successor's valid lock. +_RELEASE_LUA = ( + "if redis.call('get', KEYS[1]) == ARGV[1] then " + "return redis.call('del', KEYS[1]) " + "else return 0 end" +) + class ClusterLock: """Simple Redis-based distributed lock for preventing duplicate execution.""" @@ -116,13 +127,18 @@ class ClusterLock: return False def release(self): - """Release the lock.""" + """Release the lock. + + Owner-checked: only deletes the Redis key if the stored value still + matches our owner_id. Prevents wiping a successor's lock when the + original key was force-released externally and re-acquired. + """ with self._refresh_lock: if self._last_refresh == 0: return try: - self.redis.delete(self.key) + self.redis.eval(_RELEASE_LUA, 1, self.key, self.owner_id) except Exception: pass @@ -237,13 +253,18 @@ class AsyncClusterLock: return False async def release(self): - """Release the lock.""" + """Release the lock. + + Owner-checked: only deletes the Redis key if the stored value still + matches our owner_id. Prevents wiping a successor's lock when the + original key was force-released externally and re-acquired. + """ async with self._refresh_lock: if self._last_refresh == 0: return try: - await self.redis.delete(self.key) + await cast(Any, self.redis.eval(_RELEASE_LUA, 1, self.key, self.owner_id)) except Exception: pass diff --git a/autogpt_platform/backend/backend/executor/cluster_lock_test.py b/autogpt_platform/backend/backend/executor/cluster_lock_test.py index c5d8965f0f..5491c51cad 100644 --- a/autogpt_platform/backend/backend/executor/cluster_lock_test.py +++ b/autogpt_platform/backend/backend/executor/cluster_lock_test.py @@ -108,6 +108,33 @@ class TestClusterLockBasic: new_lock = ClusterLock(redis_client, lock_key, new_owner_id, timeout=60) assert new_lock.try_acquire() == new_owner_id + def test_release_does_not_wipe_successor_lock(self, redis_client, lock_key): + """Releasing after external delete+reacquire must NOT delete successor. + + Race: an external caller force-deletes the lock key, a new owner + acquires it, then the original ClusterLock.release() runs. Owner-checked + release must leave the successor's key intact. + """ + owner_a = str(uuid.uuid4()) + owner_b = str(uuid.uuid4()) + + lock_a = ClusterLock(redis_client, lock_key, owner_a, timeout=60) + assert lock_a.try_acquire() == owner_a + + # External force-release (e.g. mark_session_completed). + redis_client.delete(lock_key) + + # Successor acquires the same key. + lock_b = ClusterLock(redis_client, lock_key, owner_b, timeout=60) + assert lock_b.try_acquire() == owner_b + + # Original releases — must be a no-op on Redis because value != owner_a. + lock_a.release() + + # Successor's lock is still intact. + assert redis_client.exists(lock_key) == 1 + assert redis_client.get(lock_key).decode("utf-8") == owner_b + class TestClusterLockRefresh: """Lock refresh and TTL management.""" From e17e9f13c4c6832eb6bfa869534181fe37b8fa6c Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Tue, 21 Apr 2026 16:34:10 +0700 Subject: [PATCH 06/41] fix(backend/copilot): reduce SDK + baseline prompt cache waste (#12866) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Four cost-reduction changes for the copilot feature. Consolidated into one PR at user request; each commit is self-contained and bisectable. ### 1. SDK: full cross-user cache on every turn (CLI 2.1.116 bump) Previous behavior: CLI 2.1.97 crashed when `excludeDynamicSections=True` was combined with `--resume`, so the code fell back to a raw `system_prompt` string on resume, losing Claude Code's default prompt and all cache markers. Every Turn 2+ of an SDK session wrote ~33K tokens to cache instead of reading. Fix: install `@anthropic-ai/claude-code@2.1.116` in the backend Docker image and point the SDK at it via `CHAT_CLAUDE_AGENT_CLI_PATH=/usr/bin/claude`. CLI 2.1.98+ fixes the crash, so we can use the preset with `exclude_dynamic_sections=True` on every turn — Turn 1, 2, 3+ all share the same static prefix and hit the **cross-user** prompt cache. **Local dev requirement:** if `CHAT_CLAUDE_AGENT_CLI_PATH` is unset, the bundled 2.1.97 fallback will crash on `--resume`. Install the CLI globally (`npm install -g @anthropic-ai/claude-code@2.1.116`) or set the env var. ### 2. Baseline: add `cache_control` markers (commit `756b3ecd9` + follow-ups) Baseline path had zero `cache_control` across `backend/copilot/**`. Every turn was full uncached input (~18.6K tokens, ~$0.058). Two ephemeral markers — on the system message (content-blocks form) and the last tool schema — plus `anthropic-beta: prompt-caching-2024-07-31` via `extra_headers` as defense-in-depth. Helpers split into `_mark_tools_*` (precomputed once per session) and `_mark_system_*` (per-round, O(1)). Repeat hellos: ~$0.058 → ~$0.006. ### 3. Drop `get_baseline_supplement()` (commit `6e6c4d791`) `_generate_tool_documentation()` emitted ~4.3K tokens of `(tool_name, description)` pairs that exactly duplicated the tools array already in the same request. Deleted. `SHARED_TOOL_NOTES` (cross-tool workflow rules) is preserved. Baseline "hello" input: ~18.7K → ~14.4K tokens. ### 4. Langfuse "CoPilot Prompt" v26 (published under `review` label) Separate, out-of-repo change. v25 had three duplicate "Example Response" blocks + a 10-step "Internal Reasoning Process" section. v26 collapses to one example + bullet-form reasoning. Char count 20,481 → 7,075 (rough 4 chars/token → ~5,100 → ~1,770 tokens). - v26 is published with label `review` (NOT `production`); v25 remains active. - Promote via `mcp__langfuse__updatePromptLabels(name="CoPilot Prompt", version=26, newLabels=["production"])` after smoke-test. - Rollback: relabel v25 `production`. ## Test plan - [x] Unit tests for `_build_system_prompt_value` (fresh vs resumed turns emit identical preset dict) - [x] SDK compat tests pass including `test_bundled_cli_version_is_known_good_against_openrouter` - [x] `cli_openrouter_compat_test.py` passes against CLI 2.1.116 (locally verified with `CHAT_CLAUDE_AGENT_CLI_PATH=/opt/homebrew/bin/claude`) - [x] 8 new `_mark_*` unit tests + identity regression test for `_fresh_*` helpers - [x] `SHARED_TOOL_NOTES` public-constant test passes; 5 old tool-docs tests removed - [ ] **Manual cost verification (commit 1):** send two consecutive SDK turns; Turn 2 and Turn 3 should both show `cacheReadTokens` ≈ 33K (full cross-user cache hits). - [ ] **Manual cost verification (commit 2):** send two "hello" turns on baseline <5 min apart; Turn 2 reports `cacheReadTokens` ≈ 18K and cost ≈ $0.006. - [ ] **Regression sweep for commit 3:** one turn per tool family — `search_agents`, `run_agent`, `add_memory`/`forget_memory`/`search_memory`, `search_docs`, `read_workspace_file` — to verify no tool-selection regression from dropping the prose tool docs. - [ ] **Langfuse v26 smoke test:** 5-10 varied turns after relabelling to `production`; compare responses vs v25 for regression on persona, concision, capability-gap handling, credential security flows. ## Deployment notes - Production Docker image now installs CLI 2.1.116 (~20 MB added). - `CHAT_CLAUDE_AGENT_CLI_PATH=/usr/bin/claude` set in the Dockerfile; runtime can override via env. - First deploy after this merge needs a fresh image rebuild to pick up the new CLI. --- .../backend/copilot/baseline/service.py | 251 ++++++++++++-- .../copilot/baseline/service_unit_test.py | 309 +++++++++++++++++- .../backend/backend/copilot/config.py | 12 + .../backend/backend/copilot/prompting.py | 55 +--- .../backend/copilot/sdk/sdk_compat_test.py | 23 +- .../backend/backend/copilot/sdk/service.py | 46 +-- .../backend/copilot/sdk/service_test.py | 100 ++---- autogpt_platform/backend/poetry.lock | 20 +- autogpt_platform/backend/pyproject.toml | 2 +- 9 files changed, 622 insertions(+), 196 deletions(-) diff --git a/autogpt_platform/backend/backend/copilot/baseline/service.py b/autogpt_platform/backend/backend/copilot/baseline/service.py index 8a26002e25..4e495264c8 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service.py @@ -15,7 +15,7 @@ import re import shutil import tempfile import uuid -from collections.abc import AsyncGenerator, Sequence +from collections.abc import AsyncGenerator, Mapping, Sequence from dataclasses import dataclass, field from functools import partial from typing import TYPE_CHECKING, Any, cast @@ -47,7 +47,7 @@ from backend.copilot.pending_messages import ( drain_pending_messages, format_pending_as_user_message, ) -from backend.copilot.prompting import get_baseline_supplement, get_graphiti_supplement +from backend.copilot.prompting import SHARED_TOOL_NOTES, get_graphiti_supplement from backend.copilot.response_model import ( StreamBaseResponse, StreamError, @@ -168,12 +168,37 @@ def _extract_usage_cost(usage: CompletionUsage) -> float | None: def _extract_cache_creation_tokens(ptd: PromptTokensDetails) -> int: - """Read Anthropic's ``cache_creation_input_tokens`` off an OpenAI - ``PromptTokensDetails`` — it's a provider-specific extra, not in the - typed model, so we read it via ``model_extra`` rather than - ``getattr``. + """Return cache-write token count from an OpenAI-compatible + ``PromptTokensDetails``, handling provider-specific field names and + SDK-version shape differences. + + Two shapes we care about: + + - **OpenRouter** (our primary baseline provider) streams the cache-write + count as ``cache_write_tokens``. Newer ``openai-python`` versions + declare this as a typed attribute on ``PromptTokensDetails``; older + versions expose it only in ``model_extra``. Verified empirically: + cold-cache request returns ``cache_write_tokens`` > 0, warm-cache + request returns ``cached_tokens`` > 0 and ``cache_write_tokens`` = 0. + - **Direct Anthropic API** uses ``cache_creation_input_tokens`` — + never a typed attribute on the OpenAI SDK, always lives in + ``model_extra``. + + Lookup order: typed attr → ``model_extra`` (OpenRouter) → ``model_extra`` + (Anthropic-native). ``getattr`` handles both the typed-attr case + (newer SDK) and the no-such-attr case (older SDK) — we can't only use + ``model_extra`` because when the field is typed it's filtered out of + ``model_extra``, leaving us at 0 on the modern happy path. """ - return int((ptd.model_extra or {}).get("cache_creation_input_tokens") or 0) + typed_val = getattr(ptd, "cache_write_tokens", None) + if typed_val: + return int(typed_val) + extras = ptd.model_extra or {} + return int( + extras.get("cache_write_tokens") + or extras.get("cache_creation_input_tokens") + or 0 + ) async def _prepare_baseline_attachments( @@ -327,6 +352,137 @@ class _BaselineStreamState: # block only appends the *new* assistant text (avoiding duplication of # round-1 text when round-1 entries were cleared from session_messages). _flushed_assistant_text_len: int = 0 + # Memoised system-message dict with cache_control applied. The system + # prompt is static within a session, so we build it once on the first + # LLM round and reuse the same dict on subsequent rounds — avoiding + # an O(N) dict-copy of the growing ``messages`` list on every tool-call + # iteration. ``None`` means "not yet computed" (or the first message + # wasn't a system role, so no marking applies). + cached_system_message: dict[str, Any] | None = None + + +def _is_anthropic_model(model: str) -> bool: + """Return True if *model* routes to Anthropic (native or via OpenRouter). + + Cache-control markers on message content + the ``anthropic-beta`` header + are Anthropic-specific. OpenAI rejects the unknown ``cache_control`` + field with a 400 ("Extra inputs are not permitted") and Grok / other + providers behave similarly. OpenRouter strips unknown headers but + passes through ``cache_control`` on the body regardless of provider — + which would also fail when OpenRouter routes to a non-Anthropic model. + + Examples that return True: + - ``anthropic/claude-sonnet-4-6`` (OpenRouter route) + - ``claude-3-5-sonnet-20241022`` (direct Anthropic API) + - ``anthropic.claude-3-5-sonnet`` (Bedrock-style) + + False for ``openai/gpt-4o``, ``google/gemini-2.5-pro``, ``xai/grok-4`` + etc. + """ + lowered = model.lower() + return "claude" in lowered or lowered.startswith("anthropic") + + +def _fresh_ephemeral_cache_control() -> dict[str, str]: + """Return a FRESH ephemeral ``cache_control`` dict each call. + + The ``ttl`` is sourced from :attr:`ChatConfig.baseline_prompt_cache_ttl` + (default ``1h``) so the static prefix stays warm across many users' + requests in the same workspace cache. Anthropic caches are keyed + per-workspace, so every copilot user reading the same system prompt + hits the same cached entry. + + Using a shared module-level dict would let any downstream mutation + (e.g. the OpenAI SDK normalising fields in-place) poison every future + request's marker. Construction is O(1) so the safety margin is free. + """ + return {"type": "ephemeral", "ttl": config.baseline_prompt_cache_ttl} + + +def _fresh_anthropic_caching_headers() -> dict[str, str]: + """Return a FRESH ``extra_headers`` dict requesting the Anthropic + prompt-caching beta. + + Same reasoning as :func:`_fresh_ephemeral_cache_control`: never hand a + shared module-level dict to third-party SDKs. OpenRouter auto-forwards + cache_control for Anthropic routes without this header, but passing it + makes the intent unambiguous on-wire and is a no-op for non-Anthropic + providers (unknown headers are dropped). + """ + return {"anthropic-beta": "prompt-caching-2024-07-31"} + + +def _mark_tools_with_cache_control( + tools: Sequence[Mapping[str, Any]], +) -> list[dict[str, Any]]: + """Return a copy of *tools* with ``cache_control`` on the last entry. + + Marking the last tool is a cache breakpoint that covers the whole tool + schema block as a cacheable prefix segment. Extracted from + :func:`_mark_system_message_with_cache_control` so callers can precompute + the marked tool list once per session — the tool set is static within a + request and the ~43 dict-copies would otherwise run on every LLM round + in the tool-call loop. + + **Only call this for Anthropic model routes.** Non-Anthropic providers + (OpenAI, Grok, Gemini) reject the unknown ``cache_control`` field with + a 400 schema validation error. Gate via :func:`_is_anthropic_model`. + """ + cached: list[dict[str, Any]] = [dict(t) for t in tools] + if cached: + cached[-1] = { + **cached[-1], + "cache_control": _fresh_ephemeral_cache_control(), + } + return cached + + +def _build_cached_system_message( + system_message: Mapping[str, Any], +) -> dict[str, Any]: + """Return a copy of *system_message* with ``cache_control`` applied. + + Anthropic's cache uses prefix-match with up to 4 explicit breakpoints. + Combined with the last-tool marker this gives two cache segments — the + system block alone, and system+all-tools — so requests that share only + the system prefix still get a partial cache hit. + + The system message is rebuilt via spread (``{**original, ...}``) so any + unknown fields the caller set (e.g. ``name``) survive the transformation. + Non-Anthropic models silently ignore the markers. + + Returns the original dict (shallow-copied) unchanged when the content + shape is unsupported (missing / non-string / empty) — callers should + splice it into the message list as-is in that case. + """ + sys_copy = dict(system_message) + sys_content = sys_copy.get("content") + if isinstance(sys_content, str) and sys_content: + sys_copy["content"] = [ + { + "type": "text", + "text": sys_content, + "cache_control": _fresh_ephemeral_cache_control(), + } + ] + return sys_copy + + +def _mark_system_message_with_cache_control( + messages: Sequence[Mapping[str, Any]], +) -> list[dict[str, Any]]: + """Return a copy of *messages* with ``cache_control`` on the system block. + + Thin wrapper around :func:`_build_cached_system_message` that preserves + the original list shape. Prefer the memoised path in + ``_baseline_llm_caller`` (which builds the cached system dict once per + session) for hot-loop callers; this function is retained for call sites + outside the tool-call loop where per-call copying is acceptable. + """ + cached_messages: list[dict[str, Any]] = [dict(m) for m in messages] + if cached_messages and cached_messages[0].get("role") == "system": + cached_messages[0] = _build_cached_system_message(cached_messages[0]) + return cached_messages async def _baseline_llm_caller( @@ -347,28 +503,51 @@ async def _baseline_llm_caller( round_text = "" try: client = _get_openai_client() - typed_messages = cast(list[ChatCompletionMessageParam], messages) - # extra_body `usage.include=true` asks OpenRouter to embed the real - # generation cost into the final usage chunk. Without this we only get - # token counts and have no authoritative cost for rate limiting. - if tools: - typed_tools = cast(list[ChatCompletionToolParam], tools) - response = await client.chat.completions.create( - model=state.model, - messages=typed_messages, - tools=typed_tools, - stream=True, - stream_options={"include_usage": True}, - extra_body=_OPENROUTER_INCLUDE_USAGE_COST, - ) + # Cache markers are Anthropic-specific. For OpenAI/Grok/other + # providers, leaving them on would trigger a 400 ("Extra inputs + # are not permitted" on cache_control). Tools were precomputed + # in stream_chat_completion_baseline via _mark_tools_with_cache_control + # (only when the model was Anthropic), so on non-Anthropic routes + # tools ship without cache_control on the last entry too. + # + # `extra_body` `usage.include=true` asks OpenRouter to embed the real + # generation cost into the final usage chunk — required by the + # cost-based rate limiter in routes.py. Separate from the Anthropic + # caching headers, always sent. + is_anthropic = _is_anthropic_model(state.model) + if is_anthropic: + # Build the cached system dict once per session and splice it in + # on each round. The full ``messages`` list grows with every + # tool call, so copying the entire list just to mutate index 0 + # scales with conversation length (sentry flagged this); this + # splice touches only list slots, not message contents. + if ( + state.cached_system_message is None + and messages + and messages[0].get("role") == "system" + ): + state.cached_system_message = _build_cached_system_message(messages[0]) + if state.cached_system_message is not None and messages: + final_messages = [state.cached_system_message, *messages[1:]] + else: + final_messages = messages + extra_headers = _fresh_anthropic_caching_headers() else: - response = await client.chat.completions.create( - model=state.model, - messages=typed_messages, - stream=True, - stream_options={"include_usage": True}, - extra_body=_OPENROUTER_INCLUDE_USAGE_COST, - ) + final_messages = messages + extra_headers = None + typed_messages = cast(list[ChatCompletionMessageParam], final_messages) + create_kwargs: dict[str, Any] = { + "model": state.model, + "messages": typed_messages, + "stream": True, + "stream_options": {"include_usage": True}, + "extra_body": _OPENROUTER_INCLUDE_USAGE_COST, + } + if extra_headers: + create_kwargs["extra_headers"] = extra_headers + if tools: + create_kwargs["tools"] = cast(list[ChatCompletionToolParam], list(tools)) + response = await client.chat.completions.create(**create_kwargs) tool_calls_by_index: dict[int, dict[str, str]] = {} # Iterate under an inner try/finally so early exits (cancel, tool-call @@ -1170,7 +1349,7 @@ async def stream_chat_completion_baseline( graphiti_enabled = await is_enabled_for_user(user_id) graphiti_supplement = get_graphiti_supplement() if graphiti_enabled else "" - system_prompt = base_system_prompt + get_baseline_supplement() + graphiti_supplement + system_prompt = base_system_prompt + SHARED_TOOL_NOTES + graphiti_supplement # Warm context: pre-load relevant facts from Graphiti on first turn. # Use the pre-drain count so pending messages drained at turn start @@ -1320,6 +1499,18 @@ async def stream_chat_completion_baseline( if permissions is not None: tools = _filter_tools_by_permissions(tools, permissions) + # Pre-mark cache_control on the last tool schema once per session. The + # tool set is static within a request, so doing this here (instead of in + # _baseline_llm_caller) avoids re-copying ~43 tool dicts on every LLM + # round of the tool-call loop. + # + # Only apply to Anthropic routes — OpenAI/Grok/other providers would + # 400 on the unknown ``cache_control`` field inside tool definitions. + if _is_anthropic_model(active_model): + tools = cast( + list[ChatCompletionToolParam], _mark_tools_with_cache_control(tools) + ) + # Propagate execution context so tool handlers can read session-level flags. set_execution_context( user_id, @@ -1707,6 +1898,8 @@ async def stream_chat_completion_baseline( prompt_tokens=billed_prompt, completion_tokens=state.turn_completion_tokens, total_tokens=billed_prompt + state.turn_completion_tokens, + cache_read_tokens=state.turn_cache_read_tokens, + cache_creation_tokens=state.turn_cache_creation_tokens, ) yield StreamFinish() diff --git a/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py b/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py index e21618c367..4e70767426 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py @@ -13,7 +13,14 @@ from backend.copilot.baseline.service import ( _baseline_conversation_updater, _baseline_llm_caller, _BaselineStreamState, + _build_cached_system_message, _compress_session_messages, + _extract_cache_creation_tokens, + _fresh_anthropic_caching_headers, + _fresh_ephemeral_cache_control, + _is_anthropic_model, + _mark_system_message_with_cache_control, + _mark_tools_with_cache_control, ) from backend.copilot.model import ChatMessage from backend.copilot.transcript_builder import TranscriptBuilder @@ -605,11 +612,18 @@ def _make_usage_chunk( chunk.usage.model_extra = usage_extras if cached_tokens is not None or cache_creation_input_tokens is not None: - ptd = MagicMock() - ptd.cached_tokens = cached_tokens or 0 - ptd.model_extra = { - "cache_creation_input_tokens": cache_creation_input_tokens or 0 - } + # Build a real ``PromptTokensDetails`` so ``getattr(ptd, + # "cache_write_tokens", None)`` returns ``None`` on this SDK version + # (rather than a truthy MagicMock attribute) and the extraction + # helper's typed-attr vs model_extra fallback resolves correctly. + from openai.types.completion_usage import PromptTokensDetails + + ptd = PromptTokensDetails.model_validate({"cached_tokens": cached_tokens or 0}) + if cache_creation_input_tokens is not None: + if ptd.model_extra is None: + object.__setattr__(ptd, "__pydantic_extra__", {}) + assert ptd.model_extra is not None + ptd.model_extra["cache_creation_input_tokens"] = cache_creation_input_tokens chunk.usage.prompt_tokens_details = ptd else: chunk.usage.prompt_tokens_details = None @@ -1209,3 +1223,288 @@ class TestMidLoopPendingFlushOrdering: assert assistant_msgs[1].tool_calls is None # Crucially: only 2 assistant messages, not 3 (no duplicate) assert len(assistant_msgs) == 2 + + +class TestApplyPromptCacheMarkers: + """Tests for _apply_prompt_cache_markers — Anthropic ephemeral + cache_control markers on baseline OpenRouter requests.""" + + def test_system_message_converted_to_content_blocks(self): + messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hello"}, + ] + + cached_messages = _mark_system_message_with_cache_control(messages) + + assert cached_messages[0]["role"] == "system" + assert cached_messages[0]["content"] == [ + { + "type": "text", + "text": "You are helpful.", + "cache_control": {"type": "ephemeral", "ttl": "1h"}, + } + ] + # User message must be untouched. + assert cached_messages[1] == {"role": "user", "content": "hello"} + + def test_system_message_preserves_unknown_fields(self): + # Future-proofing: a system message with extra keys (e.g. "name") must + # keep them after the content-blocks conversion. + messages = [ + {"role": "system", "content": "sys", "name": "developer"}, + ] + + cached_messages = _mark_system_message_with_cache_control(messages) + + assert cached_messages[0]["name"] == "developer" + assert cached_messages[0]["role"] == "system" + + def test_last_tool_gets_cache_control(self): + tools = [ + {"type": "function", "function": {"name": "a"}}, + {"type": "function", "function": {"name": "b"}}, + ] + + cached_tools = _mark_tools_with_cache_control(tools) + + assert "cache_control" not in cached_tools[0] + assert cached_tools[-1]["cache_control"] == { + "type": "ephemeral", + "ttl": "1h", + } + # Last tool's other fields preserved. + assert cached_tools[-1]["function"] == {"name": "b"} + + def test_does_not_mutate_input(self): + messages = [{"role": "system", "content": "sys"}] + tools = [{"type": "function", "function": {"name": "a"}}] + + _mark_system_message_with_cache_control(messages) + _mark_tools_with_cache_control(tools) + + assert messages == [{"role": "system", "content": "sys"}] + assert tools == [{"type": "function", "function": {"name": "a"}}] + + def test_no_system_message_safe(self): + messages = [{"role": "user", "content": "hi"}] + cached_messages = _mark_system_message_with_cache_control(messages) + assert cached_messages == messages + + def test_empty_tools_safe(self): + assert _mark_tools_with_cache_control([]) == [] + + def test_non_string_system_content_left_untouched(self): + # If the content is already a list of blocks (e.g. caller pre-marked), + # the helper must not overwrite it. + pre_marked = [ + { + "type": "text", + "text": "sys", + "cache_control": {"type": "ephemeral", "ttl": "1h"}, + } + ] + messages = [{"role": "system", "content": pre_marked}] + cached_messages = _mark_system_message_with_cache_control(messages) + assert cached_messages[0]["content"] == pre_marked + + def test_is_anthropic_model_matches_claude_and_anthropic_prefix(self): + assert _is_anthropic_model("anthropic/claude-sonnet-4-6") + assert _is_anthropic_model("claude-3-5-sonnet-20241022") + assert _is_anthropic_model("anthropic.claude-3-5-sonnet-20241022-v2:0") + assert _is_anthropic_model("ANTHROPIC/Claude-Opus") # case insensitive + + def test_is_anthropic_model_rejects_other_providers(self): + assert not _is_anthropic_model("openai/gpt-4o") + assert not _is_anthropic_model("openai/gpt-5") + assert not _is_anthropic_model("google/gemini-2.5-pro") + assert not _is_anthropic_model("xai/grok-4") + assert not _is_anthropic_model("meta-llama/llama-3.3-70b-instruct") + + def test_cache_control_uses_configured_ttl(self, monkeypatch): + """TTL comes from ChatConfig.baseline_prompt_cache_ttl — defaults + to 1h so the static prefix (system + tools) stays warm across + workspace users past the 5-min default window.""" + from backend.copilot.baseline import service as bsvc + + assert bsvc.config.baseline_prompt_cache_ttl == "1h" + cc = bsvc._fresh_ephemeral_cache_control() + assert cc == {"type": "ephemeral", "ttl": "1h"} + monkeypatch.setattr(bsvc.config, "baseline_prompt_cache_ttl", "5m") + assert bsvc._fresh_ephemeral_cache_control() == { + "type": "ephemeral", + "ttl": "5m", + } + + def test_fresh_helpers_return_distinct_objects(self): + """Regression guard: the `_fresh_*` helpers must return a NEW dict + on every call. A future refactor returning a module-level constant + would silently reintroduce the shared-mutable-state bug flagged + during earlier review cycles.""" + assert _fresh_ephemeral_cache_control() is not _fresh_ephemeral_cache_control() + assert ( + _fresh_anthropic_caching_headers() is not _fresh_anthropic_caching_headers() + ) + + def test_extract_cache_creation_tokens_openrouter_typed_attr(self): + """Newer ``openai-python`` declares ``cache_write_tokens`` as a + typed attribute on ``PromptTokensDetails`` — it no longer lands in + ``model_extra``. Verified empirically against the production + openai==1.113 installed in this venv: OpenRouter streaming + response populates ``ptd.cache_write_tokens`` directly while + ``ptd.model_extra`` is ``{}``. + """ + from openai.types.completion_usage import PromptTokensDetails + + ptd = PromptTokensDetails.model_validate( + { + "audio_tokens": 0, + "cached_tokens": 0, + "cache_write_tokens": 4432, + "video_tokens": 0, + } + ) + assert getattr(ptd, "cache_write_tokens", None) == 4432 + assert _extract_cache_creation_tokens(ptd) == 4432 + + def test_extract_cache_creation_tokens_openrouter_model_extra(self): + """Older SDKs that don't yet declare ``cache_write_tokens`` as a + typed field leave it in ``model_extra`` — the helper must still + find it there.""" + from openai.types.completion_usage import PromptTokensDetails + + ptd = PromptTokensDetails.model_validate({"cached_tokens": 0}) + # Force the value into model_extra (simulates the old SDK shape + # where the field wasn't typed yet). + if ptd.model_extra is None: + # Pydantic v2 sometimes exposes __pydantic_extra__ as None when + # extras are disabled; initialise to a dict to mutate safely. + object.__setattr__(ptd, "__pydantic_extra__", {}) + assert ptd.model_extra is not None + ptd.model_extra["cache_write_tokens"] = 7777 + assert _extract_cache_creation_tokens(ptd) == 7777 + + def test_extract_cache_creation_tokens_anthropic_native_field(self): + """Direct Anthropic API uses ``cache_creation_input_tokens`` — + falls through as the final path when neither + ``cache_write_tokens`` typed attr nor model_extra entry exists.""" + from openai.types.completion_usage import PromptTokensDetails + + ptd = PromptTokensDetails.model_validate({"cached_tokens": 0}) + if ptd.model_extra is None: + object.__setattr__(ptd, "__pydantic_extra__", {}) + assert ptd.model_extra is not None + ptd.model_extra["cache_creation_input_tokens"] = 2048 + assert _extract_cache_creation_tokens(ptd) == 2048 + + def test_extract_cache_creation_tokens_absent(self): + """Neither provider field present → 0 (non-Anthropic routes or + cache-miss responses).""" + from openai.types.completion_usage import PromptTokensDetails + + ptd = PromptTokensDetails.model_validate({"cached_tokens": 0}) + assert _extract_cache_creation_tokens(ptd) == 0 + + def test_build_cached_system_message_applies_cache_control(self): + """The single-message helper wraps the string content in a text block + with an ephemeral cache_control marker.""" + out = _build_cached_system_message({"role": "system", "content": "hi"}) + assert out["role"] == "system" + assert out["content"] == [ + { + "type": "text", + "text": "hi", + "cache_control": {"type": "ephemeral", "ttl": "1h"}, + } + ] + + def test_build_cached_system_message_preserves_extra_fields(self): + """Unknown keys (e.g. ``name``) survive the transformation.""" + out = _build_cached_system_message( + {"role": "system", "content": "sys", "name": "dev"} + ) + assert out["name"] == "dev" + assert out["role"] == "system" + + def test_build_cached_system_message_non_string_passthrough(self): + """Pre-marked list content is returned as-is (shallow-copied).""" + pre_marked = [ + { + "type": "text", + "text": "sys", + "cache_control": {"type": "ephemeral", "ttl": "1h"}, + } + ] + out = _build_cached_system_message({"role": "system", "content": pre_marked}) + assert out["content"] is pre_marked + + @pytest.mark.asyncio + async def test_baseline_llm_caller_memoises_cached_system_message(self): + """The cached system dict is built once and reused across rounds. + + Guards against the perf regression where the entire (growing) + ``messages`` list was copied on every tool-call iteration just to + mark the static system prompt. + """ + state = _BaselineStreamState(model="anthropic/claude-sonnet-4") + chunk = _make_usage_chunk(prompt_tokens=10, completion_tokens=5) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + side_effect=[_make_stream_mock(chunk), _make_stream_mock(chunk)] + ) + + messages: list[dict] = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hi"}, + ] + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller(messages=messages, tools=[], state=state) + first_cached = state.cached_system_message + assert first_cached is not None + # Simulate the tool-call loop growing ``messages`` between rounds. + messages.append({"role": "assistant", "content": "ok"}) + messages.append({"role": "user", "content": "follow up"}) + await _baseline_llm_caller(messages=messages, tools=[], state=state) + + # Same dict instance reused — not rebuilt per round. + assert state.cached_system_message is first_cached + + # Second call's first message is the memoised system dict (not a new copy). + second_call_messages = mock_client.chat.completions.create.call_args_list[1][1][ + "messages" + ] + assert second_call_messages[0] is first_cached + # And the tail messages were spliced in, not re-copied. + assert second_call_messages[1] is messages[1] + assert second_call_messages[-1] is messages[-1] + + @pytest.mark.asyncio + async def test_baseline_llm_caller_skips_memoisation_for_non_anthropic(self): + """Non-Anthropic routes pass messages through unmodified — no cache + dict is built, no list splicing happens.""" + state = _BaselineStreamState(model="openai/gpt-4o") + chunk = _make_usage_chunk(prompt_tokens=10, completion_tokens=5) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(chunk) + ) + + messages: list[dict] = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "hi"}, + ] + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller(messages=messages, tools=[], state=state) + + assert state.cached_system_message is None + # The exact same list object reaches the provider (no copy needed). + call_messages = mock_client.chat.completions.create.call_args[1]["messages"] + assert call_messages is messages diff --git a/autogpt_platform/backend/backend/copilot/config.py b/autogpt_platform/backend/backend/copilot/config.py index 3277854172..1080921fd8 100644 --- a/autogpt_platform/backend/backend/copilot/config.py +++ b/autogpt_platform/backend/backend/copilot/config.py @@ -225,6 +225,18 @@ class ChatConfig(BaseSettings): "from the prefix. Set to False to fall back to passing the system " "prompt as a raw string.", ) + baseline_prompt_cache_ttl: str = Field( + default="1h", + description="TTL for the ephemeral prompt-cache markers on the baseline " + "OpenRouter path. Anthropic supports only `5m` (default, 1.25x input " + "price for the write) or `1h` (2x input price for the write). 1h is " + "strictly cheaper overall when the static prefix gets >7 reads per " + "write-window; since the system prompt + tools array is identical " + "across all users in our workspace, 1h is the default so cross-user " + "reads amortise the higher write cost. Anthropic has no longer " + "(24h, permanent) TTL option — see " + "https://platform.claude.com/docs/en/build-with-claude/prompt-caching.", + ) claude_agent_cli_path: str | None = Field( default=None, description="Optional explicit path to a Claude Code CLI binary. " diff --git a/autogpt_platform/backend/backend/copilot/prompting.py b/autogpt_platform/backend/backend/copilot/prompting.py index 2f52bd460d..399d31c1cc 100644 --- a/autogpt_platform/backend/backend/copilot/prompting.py +++ b/autogpt_platform/backend/backend/copilot/prompting.py @@ -8,10 +8,12 @@ handling the distinction between: from functools import cache -from backend.copilot.tools import TOOL_REGISTRY - -# Shared technical notes that apply to both SDK and baseline modes -_SHARED_TOOL_NOTES = """\ +# Workflow rules appended to the system prompt on every copilot turn +# (baseline appends directly; SDK appends via the storage-supplement +# template). These are cross-tool rules (file sharing, @@agptfile: refs, +# tool-discovery priority, sub-agent etiquette) that don't belong on any +# individual tool schema. +SHARED_TOOL_NOTES = """\ ### Sharing files After `write_workspace_file`, embed the `download_url` in Markdown: @@ -261,7 +263,7 @@ When a tool output contains ``, the full output is in workspace storage (NOT on the local filesystem). To access it: - Use `read_workspace_file(path="...", offset=..., length=50000)` for reading sections. - To process in the sandbox, use `read_workspace_file(path="...", save_to_path="{working_dir}/file.json")` first, then use `bash_exec` on the local copy. -{_SHARED_TOOL_NOTES}{extra_notes}""" +{SHARED_TOOL_NOTES}{extra_notes}""" # Pre-built supplements for common environments @@ -312,35 +314,6 @@ def _get_cloud_sandbox_supplement() -> str: ) -def _generate_tool_documentation() -> str: - """Auto-generate tool documentation from TOOL_REGISTRY. - - NOTE: This is ONLY used in baseline mode (direct OpenAI API). - SDK mode doesn't need it since Claude gets tool schemas automatically. - - This generates a complete list of available tools with their descriptions, - ensuring the documentation stays in sync with the actual tool implementations. - All workflow guidance is now embedded in individual tool descriptions. - - Only documents tools that are available in the current environment - (checked via tool.is_available property). - """ - docs = "\n## AVAILABLE TOOLS\n\n" - - # Sort tools alphabetically for consistent output - # Filter by is_available to match get_available_tools() behavior - for name in sorted(TOOL_REGISTRY.keys()): - tool = TOOL_REGISTRY[name] - if not tool.is_available: - continue - schema = tool.as_openai_tool() - desc = schema["function"].get("description", "No description available") - # Format as bullet list with tool name in code style - docs += f"- **`{name}`**: {desc}\n" - - return docs - - _USER_FOLLOW_UP_NOTE = """ # `` blocks in tool output @@ -438,17 +411,3 @@ You have access to persistent temporal memory tools that remember facts across s - group_id is handled automatically by the system — never set it yourself. - When storing, be specific about operational rules and instructions (e.g., "CC Sarah on client communications" not just "Sarah is the assistant"). """ - - -def get_baseline_supplement() -> str: - """Get the supplement for baseline mode (direct OpenAI API). - - Baseline mode INCLUDES auto-generated tool documentation because the - direct API doesn't automatically provide tool schemas to Claude. - Also includes shared technical notes (but NOT SDK-specific environment details). - - Returns: - The supplement string to append to the system prompt - """ - tool_docs = _generate_tool_documentation() - return tool_docs + _SHARED_TOOL_NOTES diff --git a/autogpt_platform/backend/backend/copilot/sdk/sdk_compat_test.py b/autogpt_platform/backend/backend/copilot/sdk/sdk_compat_test.py index 5d132aa94d..7cf8af3396 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/sdk_compat_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/sdk_compat_test.py @@ -94,21 +94,23 @@ def test_agent_options_accepts_required_fields(): def test_agent_options_accepts_system_prompt_preset_with_exclude_dynamic_sections(): """Verify ClaudeAgentOptions accepts the exact preset dict _build_system_prompt_value produces. - The production code always includes ``exclude_dynamic_sections=True`` in the preset - dict. This compat test mirrors that exact shape so any SDK version that starts - rejecting unknown keys will be caught here rather than at runtime. + The Turn 1 (non-resume) code path includes ``exclude_dynamic_sections=True`` in + the preset dict for cross-user caching. This compat test mirrors that exact + shape so any SDK version that starts rejecting unknown keys will be caught + here rather than at runtime. """ from claude_agent_sdk import ClaudeAgentOptions from claude_agent_sdk.types import SystemPromptPreset from .service import _build_system_prompt_value - # Call the production helper directly so this test is tied to the real - # dict shape rather than a hand-rolled copy. preset = _build_system_prompt_value("custom system prompt", cross_user_cache=True) assert isinstance( preset, dict ), "_build_system_prompt_value must return a dict when caching is on" + assert preset.get("exclude_dynamic_sections") is True, ( + "Turn 1 must strip dynamic sections to keep the prefix cacheable " "cross-user" + ) sdk_preset = cast(SystemPromptPreset, preset) opts = ClaudeAgentOptions(system_prompt=sdk_preset) @@ -116,8 +118,9 @@ def test_agent_options_accepts_system_prompt_preset_with_exclude_dynamic_section def test_build_system_prompt_value_returns_plain_string_when_cross_user_cache_off(): - """When cross_user_cache=False (e.g. on --resume turns), the helper must return - a plain string so the preset+resume crash is avoided.""" + """When cross_user_cache=False (feature flag disabled globally), the + helper returns a plain string; the CLI will receive --system-prompt + (replace-mode) and skip the preset entirely.""" from .service import _build_system_prompt_value result = _build_system_prompt_value("my prompt", cross_user_cache=False) @@ -262,6 +265,12 @@ _KNOWN_GOOD_BUNDLED_CLI_VERSIONS: frozenset[str] = frozenset( "2.1.97", # claude-agent-sdk 0.1.58 -- OpenRouter-safe only with # CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS=1 (injected by # build_sdk_env() in env.py). + "2.1.116", # claude-agent-sdk 0.1.64 -- first bundled version that + # fixes the --resume + excludeDynamicSections=True crash + # (introduced in 2.1.98), unlocking cross-user prompt + # cache reads on every resumed SDK turn. Still requires + # CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS=1. Verified + # OpenRouter-safe via cli_openrouter_compat_test.py. } ) diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py index e4f29a2b65..8fe8aa12df 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service.py @@ -836,16 +836,25 @@ def _is_fallback_stderr(line: str) -> bool: def _build_system_prompt_value( system_prompt: str, + *, cross_user_cache: bool, ) -> str | SystemPromptPreset: """Build the ``system_prompt`` argument for :class:`ClaudeAgentOptions`. When *cross_user_cache* is enabled, returns a :class:`SystemPromptPreset` - dict so the Claude Code default prompt becomes a cacheable prefix shared - across all users; our custom *system_prompt* is appended after it. + with ``exclude_dynamic_sections=True`` so every turn — Turn 1 *and* + resumed turns — shares the same static prefix and hits the cross-user + prompt cache. Our custom *system_prompt* is appended after the preset. - When disabled (or if the SDK is too old to support ``SystemPromptPreset``), - the raw *system_prompt* string is returned unchanged. + Requires CLI ≥ 2.1.98 (older CLIs crash when ``excludeDynamicSections`` + is combined with ``--resume``). The SDK bundles CLI 2.1.116 at + ``claude-agent-sdk >= 0.1.64``, so the pin in ``pyproject.toml`` is + the single source of truth — no external install needed. + + When *cross_user_cache* is disabled, the raw *system_prompt* string is + returned. Note this causes the CLI to REPLACE its built-in prompt via + ``--system-prompt`` (vs ``--append-system-prompt`` for the preset), + which loses Claude Code's default prompt and its cache markers entirely. An empty *system_prompt* is accepted: the preset dict will have ``append: ""`` which the SDK treats as no custom suffix. @@ -3036,15 +3045,17 @@ async def stream_chat_completion_sdk( sid, ) - # Use SystemPromptPreset for cross-user prompt caching. - # WORKAROUND: CLI 2.1.97 (sdk 0.1.58) exits code 1 when - # excludeDynamicSections=True is in the initialize request AND - # --resume is active. Disable the preset on resumed turns. - # Turn 1 still gets the preset (no --resume). - _cross_user = config.claude_agent_cross_user_prompt_cache and not use_resume + # Use SystemPromptPreset with exclude_dynamic_sections=True on + # every turn — including resumed ones — so all turns share the + # same static prefix and hit the cross-user prompt cache. + # + # Requires CLI ≥ 2.1.98 (older CLIs crash when excludeDynamicSections + # is combined with --resume). claude-agent-sdk >= 0.1.64 bundles + # CLI 2.1.116, so the pin in pyproject.toml is sufficient — no + # external install or env-var override needed. system_prompt_value = _build_system_prompt_value( system_prompt, - cross_user_cache=_cross_user, + cross_user_cache=config.claude_agent_cross_user_prompt_cache, ) sdk_options_kwargs: dict[str, Any] = { @@ -3401,15 +3412,12 @@ async def stream_chat_completion_sdk( # fail with "Session ID already in use". sdk_options_kwargs_retry.pop("resume", None) sdk_options_kwargs_retry.pop("session_id", None) - # Recompute system_prompt for retry — ctx.use_resume may have - # changed (context reduction enabled --resume). CLI 2.1.97 - # crashes when excludeDynamicSections=True is combined with - # --resume, so disable the cross-user preset on resumed turns. - _cross_user_retry = ( - config.claude_agent_cross_user_prompt_cache and not ctx.use_resume - ) + # Recompute system_prompt for retry — the preset is safe on + # every turn (requires CLI ≥ 2.1.98, installed in the Docker + # image and configured via CHAT_CLAUDE_AGENT_CLI_PATH). sdk_options_kwargs_retry["system_prompt"] = _build_system_prompt_value( - system_prompt, cross_user_cache=_cross_user_retry + system_prompt, + cross_user_cache=config.claude_agent_cross_user_prompt_cache, ) state.options = ClaudeAgentOptions(**sdk_options_kwargs_retry) # type: ignore[arg-type] # dynamic kwargs # Retry intentionally omits prior_messages (transcript+gap context) and diff --git a/autogpt_platform/backend/backend/copilot/sdk/service_test.py b/autogpt_platform/backend/backend/copilot/sdk/service_test.py index f7ebe766f6..d47f67252a 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service_test.py @@ -177,70 +177,18 @@ class TestPromptSupplement: assert "## Tool notes" in local_supplement assert "## Tool notes" in e2b_supplement - def test_baseline_supplement_includes_tool_docs(self): - """Baseline mode MUST include tool documentation (direct API needs it).""" - from backend.copilot.prompting import get_baseline_supplement + def test_baseline_supplement_has_shared_notes_no_tool_list(self): + """Baseline now relies on the OpenAI tools array for schemas and only + appends SHARED_TOOL_NOTES (workflow rules not present in any schema). + The old auto-generated ``## AVAILABLE TOOLS`` list is gone — it was + ~4.3K tokens of pure duplication of the tools array.""" + from backend.copilot.prompting import SHARED_TOOL_NOTES - supplement = get_baseline_supplement() - - # MUST have tool list section - assert "## AVAILABLE TOOLS" in supplement - - # Should NOT have environment-specific notes (SDK-only) - assert "## Tool notes" not in supplement - - def test_baseline_supplement_includes_key_tools(self): - """Baseline supplement should document all essential tools.""" - from backend.copilot.prompting import get_baseline_supplement - from backend.copilot.tools import TOOL_REGISTRY - - docs = get_baseline_supplement() - - # Core agent workflow tools (always available) - assert "`create_agent`" in docs - assert "`run_agent`" in docs - assert "`find_library_agent`" in docs - assert "`edit_agent`" in docs - - # MCP integration (always available) - assert "`run_mcp_tool`" in docs - - # Folder management (always available) - assert "`create_folder`" in docs - - # Browser tools only if available (Playwright may not be installed in CI) - if ( - TOOL_REGISTRY.get("browser_navigate") - and TOOL_REGISTRY["browser_navigate"].is_available - ): - assert "`browser_navigate`" in docs - - def test_baseline_supplement_includes_workflows(self): - """Baseline supplement should include workflow guidance in tool descriptions.""" - from backend.copilot.prompting import get_baseline_supplement - - docs = get_baseline_supplement() - - # Workflows are now in individual tool descriptions (not separate sections) - # Check that key workflow concepts appear in tool descriptions - assert "agent_json" in docs or "find_block" in docs - assert "run_mcp_tool" in docs - - def test_baseline_supplement_completeness(self): - """All available tools from TOOL_REGISTRY should appear in baseline supplement.""" - from backend.copilot.prompting import get_baseline_supplement - from backend.copilot.tools import TOOL_REGISTRY - - docs = get_baseline_supplement() - - # Verify each available registered tool is documented - # (matches _generate_tool_documentation which filters by is_available) - for tool_name, tool in TOOL_REGISTRY.items(): - if not tool.is_available: - continue - assert ( - f"`{tool_name}`" in docs - ), f"Tool '{tool_name}' missing from baseline supplement" + assert "## AVAILABLE TOOLS" not in SHARED_TOOL_NOTES + # Keep the high-value workflow rules that are NOT in any tool schema. + assert "@@agptfile:" in SHARED_TOOL_NOTES + assert "Tool Discovery Priority" in SHARED_TOOL_NOTES + assert "run_sub_session" in SHARED_TOOL_NOTES def test_pause_task_scheduled_before_transcript_upload(self): """Pause is scheduled as a background task before transcript upload begins. @@ -284,21 +232,6 @@ class TestPromptSupplement: # concurrently during upload's first yield. The ordering guarantee is # that create_task is CALLED before upload is AWAITED (see source order). - def test_baseline_supplement_no_duplicate_tools(self): - """No tool should appear multiple times in baseline supplement.""" - from backend.copilot.prompting import get_baseline_supplement - from backend.copilot.tools import TOOL_REGISTRY - - docs = get_baseline_supplement() - - # Count occurrences of each available tool in the entire supplement - for tool_name, tool in TOOL_REGISTRY.items(): - if not tool.is_available: - continue - # Count how many times this tool appears as a bullet point - count = docs.count(f"- **`{tool_name}`**") - assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)" - # --------------------------------------------------------------------------- # _cleanup_sdk_tool_results — orchestration + rate-limiting @@ -700,6 +633,17 @@ class TestSystemPromptPreset: assert result["append"] == "" assert result["exclude_dynamic_sections"] is True + def test_resume_and_fresh_share_the_same_static_prefix(self): + """Every turn (fresh + --resume) must emit the same preset dict + so the cross-user cache prefix match works on all turns. This + relies on CLI ≥ 2.1.98 (installed in the Docker image); older + CLIs would crash on --resume + excludeDynamicSections=True.""" + fresh = _build_system_prompt_value("sys", cross_user_cache=True) + resumed = _build_system_prompt_value("sys", cross_user_cache=True) + assert fresh == resumed + assert isinstance(fresh, dict) + assert fresh.get("exclude_dynamic_sections") is True + def test_default_config_is_enabled(self, _clean_config_env): """The default value for claude_agent_cross_user_prompt_cache is True.""" cfg = cfg_mod.ChatConfig( diff --git a/autogpt_platform/backend/poetry.lock b/autogpt_platform/backend/poetry.lock index 03c93c286a..a9aafef96f 100644 --- a/autogpt_platform/backend/poetry.lock +++ b/autogpt_platform/backend/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "agentmail" @@ -909,18 +909,18 @@ files = [ [[package]] name = "claude-agent-sdk" -version = "0.1.58" +version = "0.1.64" description = "Python SDK for Claude Code" optional = false python-versions = ">=3.10" groups = ["main"] files = [ - {file = "claude_agent_sdk-0.1.58-py3-none-macosx_11_0_arm64.whl", hash = "sha256:69197950809754c4f06bba8261f2d99c3f9605b6cc1c13d3409d0eb82fb4ee64"}, - {file = "claude_agent_sdk-0.1.58-py3-none-macosx_11_0_x86_64.whl", hash = "sha256:75d60883fc5e2070bccd8d9b19505fe16af8e049120c03821e9dc8c826cca434"}, - {file = "claude_agent_sdk-0.1.58-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:7bf4eb0f00ec944a7b63eb94788f120dfb0460c348a525235c7d6641805acc1d"}, - {file = "claude_agent_sdk-0.1.58-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:650d298a3d3c0dcdde4b5f1dbf52f472ff0b0ec82987b27ffa2a4e0e72928408"}, - {file = "claude_agent_sdk-0.1.58-py3-none-win_amd64.whl", hash = "sha256:2c2130a7ffe06ed4f88d56b217a5091c91c9bcb1a69cfd94d5dcf0d2946d8c55"}, - {file = "claude_agent_sdk-0.1.58.tar.gz", hash = "sha256:77bee8fd60be033cb870def46c2ab1625a512fa8a3de4ff8d766664ffb16d6a6"}, + {file = "claude_agent_sdk-0.1.64-py3-none-macosx_11_0_arm64.whl", hash = "sha256:4cf47a9e40c0a683a05afff4fac1e3d5ea7965b1e9f72a8e266c8d2efbf65904"}, + {file = "claude_agent_sdk-0.1.64-py3-none-macosx_11_0_x86_64.whl", hash = "sha256:7fe765c6482c74bc6b0b4491ad3bddd1349c25f4cdf4483191c68ea9c1336825"}, + {file = "claude_agent_sdk-0.1.64-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:605eebf46e7590e4f878572c2743954fba3f3530dfd99e10ff3b8b41a9fee757"}, + {file = "claude_agent_sdk-0.1.64-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:bbb1373ee0b4494e2db24aa10d312d22b86895b4b8f18eb5b58f99f14d827237"}, + {file = "claude_agent_sdk-0.1.64-py3-none-win_amd64.whl", hash = "sha256:453fa251e2a4aeed580c72d4c7b2cb98fc8d8d26012798126f5cb11a9829cd71"}, + {file = "claude_agent_sdk-0.1.64.tar.gz", hash = "sha256:147e513cb45095b57c37d74b8d01dd41b5f3ec7f70e408edce43a6590159c27d"}, ] [package.dependencies] @@ -930,6 +930,8 @@ typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} [package.extras] dev = ["anyio[trio] (>=4.0.0)", "mypy (>=1.0.0)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.20.0)", "pytest-cov (>=4.0.0)", "ruff (>=0.1.0)"] +examples = ["asyncpg (>=0.27.0)", "boto3 (>=1.28.0)", "fakeredis (>=2.20.0)", "moto[s3] (>=5.0.0)", "redis (>=4.2.0)"] +otel = ["opentelemetry-api (>=1.20.0)"] [[package]] name = "cleo" @@ -8929,4 +8931,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "c4cc6a0a26869a167ce182b178224554135d89d8ffa4605257d17b3f495cdf59" +content-hash = "529e1acbb1213421ef617f9dab309787cf81ea5d787eeffebc1bd38a42daf976" diff --git a/autogpt_platform/backend/pyproject.toml b/autogpt_platform/backend/pyproject.toml index ea81390d81..6e7003a65d 100644 --- a/autogpt_platform/backend/pyproject.toml +++ b/autogpt_platform/backend/pyproject.toml @@ -18,7 +18,7 @@ apscheduler = "^3.11.1" autogpt-libs = { path = "../autogpt_libs", develop = true } bleach = { extras = ["css"], version = "^6.2.0" } cachetools = "^5.5.0" -claude-agent-sdk = "0.1.58" # latest stable; bundled CLI 2.1.97 -- CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS=1 env var strips the broken context-management beta. See sdk_compat_test.py. +claude-agent-sdk = "^0.1.64" # bundled CLI 2.1.116 -- 2.1.98+ fixes the --resume + excludeDynamicSections crash that used to force a per-turn 33K cache write. CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS=1 env var strips the broken context-management beta. See sdk_compat_test.py. click = "^8.2.0" cryptography = "^46.0" discord-py = "^2.5.2" From 24850e2a3e7ca3a1a06e40005385041f723dfcaf Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Tue, 21 Apr 2026 21:05:00 +0700 Subject: [PATCH 07/41] feat(backend/autopilot): stream extended_thinking on baseline via OpenRouter (#12870) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Why / What / How **Why:** Fast-mode autopilot never renders a Reasoning block. The frontend already has `ReasoningCollapse` wired up and the wire protocol already carries `StreamReasoning*` events (landed for SDK mode in #12853), but the baseline (OpenRouter OpenAI-compat) path never asks Anthropic for extended thinking and never parses reasoning deltas off the stream. Result: users on fast/standard get a good answer with no visible chain-of-thought, while SDK users see the full Reasoning collapse. **What:** Plumb reasoning end-to-end through the baseline path by opting into OpenRouter's non-OpenAI `reasoning` extension, parsing the reasoning delta fields off each chunk, and emitting the same `StreamReasoningStart/Delta/End` events the SDK adapter already uses. **How:** - **New config:** `baseline_reasoning_max_tokens` (default 8192; 0 disables). Sent as `extra_body={"reasoning": {"max_tokens": N}}` only on Anthropic routes — other providers drop the field, and `is_anthropic_model()` already gates this. - **Delta extraction:** `_extract_reasoning_delta()` handles all three OpenRouter/provider variants in priority order — legacy `delta.reasoning` (string), DeepSeek-style `delta.reasoning_content`, and the structured `delta.reasoning_details` list (text/summary entries; encrypted or unknown entries are skipped). - **Event emission:** Reasoning uses the same state-machine rules the SDK adapter uses — a text delta or tool_use delta arriving mid-stream closes the open reasoning block first, so the AI SDK v5 transport keeps reasoning / text / tool-use as distinct UI parts. On stream end, any still-open reasoning block gets a matching `reasoning-end` so a reasoning-only turn still finalises the frontend collapse. - **Scope:** Live streaming only. Reasoning is not persisted to `ChatMessage` rows or the transcript builder in this PR (SDK path does so via `content_blocks=[{type: 'thinking', ...}]`, but that round-trip requires Anthropic signature plumbing baseline doesn't have today). Reload will still not show reasoning on baseline sessions — can follow up if we decide it's worth the signature handling. ### Changes - `backend/copilot/config.py` — new `baseline_reasoning_max_tokens` field. - `backend/copilot/baseline/service.py` — new `_extract_reasoning_delta()` helper; reasoning block state on `_BaselineStreamState`; `reasoning` gated into `extra_body`; chunk loop emits `StreamReasoning*` events with text/tool_use transition rules; stream-end closes any open reasoning block. - `backend/copilot/baseline/service_unit_test.py` — 11 new tests covering extractor variants (legacy string, deepseek alias, structured list with text/summary aliases, encrypted-skip, empty), paired event ordering (reasoning-end before text-start), reasoning-only streams, and that the `reasoning` request param is correctly gated by model route (Anthropic vs non-Anthropic) and by the config flag. ### Checklist For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [ ] I have tested my changes according to the test plan: - [x] `poetry run pytest backend/copilot/baseline/service_unit_test.py backend/copilot/baseline/transcript_integration_test.py` — 103 passed - [ ] Manual: with `CHAT_USE_CLAUDE_AGENT_SDK=false` and `CHAT_MODEL=anthropic/claude-sonnet-4-6`, send a multi-step prompt on fast mode and confirm a Reasoning collapse appears alongside the final text - [ ] Manual: flip `CHAT_BASELINE_REASONING_MAX_TOKENS=0` and confirm baseline responses revert to text-only (no reasoning param, no reasoning UI) - [ ] Manual: with a non-Anthropic baseline model (`openai/gpt-4o`), confirm the request does NOT include `reasoning` and nothing regresses For configuration changes: - [x] `.env.default` is compatible — new setting falls back to the pydantic default --- .../backend/copilot/baseline/reasoning.py | 230 +++++++++++ .../copilot/baseline/reasoning_test.py | 281 ++++++++++++++ .../backend/copilot/baseline/service.py | 70 +++- .../copilot/baseline/service_unit_test.py | 365 ++++++++++++++++++ .../backend/backend/copilot/config.py | 16 +- .../copilot/sdk/retry_scenarios_test.py | 2 + .../backend/backend/copilot/sdk/service.py | 19 +- 7 files changed, 950 insertions(+), 33 deletions(-) create mode 100644 autogpt_platform/backend/backend/copilot/baseline/reasoning.py create mode 100644 autogpt_platform/backend/backend/copilot/baseline/reasoning_test.py diff --git a/autogpt_platform/backend/backend/copilot/baseline/reasoning.py b/autogpt_platform/backend/backend/copilot/baseline/reasoning.py new file mode 100644 index 0000000000..15a77dde8a --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/baseline/reasoning.py @@ -0,0 +1,230 @@ +"""Extended-thinking wire support for the baseline (OpenRouter) path. + +Anthropic routes on OpenRouter expose extended thinking through +non-OpenAI extension fields that the OpenAI Python SDK doesn't model: + +* ``reasoning`` (legacy string) — enabled by ``include_reasoning: true``. +* ``reasoning_content`` — DeepSeek / some OpenRouter routes. +* ``reasoning_details`` — structured list shipped with the unified + ``reasoning`` request param. + +This module keeps the wire-level concerns in one place: + +* :class:`OpenRouterDeltaExtension` validates the extension dict pulled off + ``ChoiceDelta.model_extra`` into typed pydantic models — no ``getattr`` + + ``isinstance`` duck-typing at the call site. +* :class:`BaselineReasoningEmitter` owns the reasoning block lifecycle for + one streaming round and emits ``StreamReasoning*`` events so the caller + only has to plumb the events into its pending queue. +* :func:`reasoning_extra_body` builds the ``extra_body`` fragment for the + OpenAI client call. Returns ``None`` on non-Anthropic routes. +""" + +from __future__ import annotations + +import logging +import uuid +from typing import Any + +from openai.types.chat.chat_completion_chunk import ChoiceDelta +from pydantic import BaseModel, ConfigDict, Field, ValidationError + +from backend.copilot.model import ChatMessage +from backend.copilot.response_model import ( + StreamBaseResponse, + StreamReasoningDelta, + StreamReasoningEnd, + StreamReasoningStart, +) + +logger = logging.getLogger(__name__) + + +_VISIBLE_REASONING_TYPES = frozenset({"reasoning.text", "reasoning.summary"}) + + +class ReasoningDetail(BaseModel): + """One entry in OpenRouter's ``reasoning_details`` list. + + OpenRouter ships ``type: "reasoning.text"`` / ``"reasoning.summary"`` / + ``"reasoning.encrypted"`` entries. Only the first two carry + user-visible text; encrypted entries are opaque and omitted from the + rendered collapse. Unknown future types are tolerated (``extra="ignore"``) + so an upstream addition doesn't crash the stream — but their ``text`` / + ``summary`` fields are NOT surfaced because they may carry provider + metadata rather than user-visible reasoning (see + :attr:`visible_text`). + """ + + model_config = ConfigDict(extra="ignore") + + type: str | None = None + text: str | None = None + summary: str | None = None + + @property + def visible_text(self) -> str: + """Return the human-readable text for this entry, or ``""``. + + Only entries with a recognised reasoning type (``reasoning.text`` / + ``reasoning.summary``) surface text; unknown or encrypted types + return an empty string even if they carry a ``text`` / + ``summary`` field, to guard against future provider metadata + being rendered as reasoning in the UI. Entries missing a + ``type`` are treated as text (pre-``reasoning_details`` OpenRouter + payloads omit the field). + """ + if self.type is not None and self.type not in _VISIBLE_REASONING_TYPES: + return "" + return self.text or self.summary or "" + + +class OpenRouterDeltaExtension(BaseModel): + """Non-OpenAI fields OpenRouter adds to streaming deltas. + + Instantiate via :meth:`from_delta` which pulls the extension dict off + ``ChoiceDelta.model_extra`` (where pydantic v2 stashes fields that + aren't part of the declared schema) and validates it through this + model. That keeps the parser honest — malformed entries surface as + validation errors rather than silent ``None``-coalesce bugs — and + avoids the ``getattr`` + ``isinstance`` duck-typing the earlier inline + extractor relied on. + """ + + model_config = ConfigDict(extra="ignore") + + reasoning: str | None = None + reasoning_content: str | None = None + reasoning_details: list[ReasoningDetail] = Field(default_factory=list) + + @classmethod + def from_delta(cls, delta: ChoiceDelta) -> "OpenRouterDeltaExtension": + """Build an extension view from ``delta.model_extra``. + + Malformed provider payloads (e.g. ``reasoning_details`` shipped as + a string rather than a list) surface as a ``ValidationError`` which + is logged and swallowed — returning an empty extension so the rest + of the stream (valid text / tool calls) keeps flowing. An optional + feature's corrupted wire data must never abort the whole stream. + """ + try: + return cls.model_validate(delta.model_extra or {}) + except ValidationError as exc: + logger.warning( + "[Baseline] Dropping malformed OpenRouter reasoning payload: %s", + exc, + ) + return cls() + + def visible_text(self) -> str: + """Concatenated reasoning text, pulled from whichever channel is set. + + Priority: the legacy ``reasoning`` string, then DeepSeek's + ``reasoning_content``, then the concatenation of text-bearing + entries in ``reasoning_details``. Only one channel is set per + provider in practice; the priority order just makes the fallback + deterministic if a provider ever emits multiple. + """ + if self.reasoning: + return self.reasoning + if self.reasoning_content: + return self.reasoning_content + return "".join(d.visible_text for d in self.reasoning_details) + + +def reasoning_extra_body(model: str, max_thinking_tokens: int) -> dict[str, Any] | None: + """Build the ``extra_body["reasoning"]`` fragment for the OpenAI client. + + Returns ``None`` for non-Anthropic routes (other OpenRouter providers + ignore the field but we skip it anyway to keep the payload minimal) + and for ``max_thinking_tokens <= 0`` (operator kill switch). + """ + # Imported lazily to avoid pulling service.py at module load — service.py + # imports this module, and the lazy import keeps the dependency one-way. + from backend.copilot.baseline.service import _is_anthropic_model + + if not _is_anthropic_model(model) or max_thinking_tokens <= 0: + return None + return {"reasoning": {"max_tokens": max_thinking_tokens}} + + +class BaselineReasoningEmitter: + """Owns the reasoning block lifecycle for one streaming round. + + Two concerns live here, both driven by the same state machine: + + 1. **Wire events.** The AI SDK v6 wire format pairs every + ``reasoning-start`` with a matching ``reasoning-end`` and treats + reasoning / text / tool-use as distinct UI parts that must not + interleave. + 2. **Session persistence.** ``ChatMessage(role="reasoning")`` rows in + ``session.messages`` are what + ``convertChatSessionToUiMessages.ts`` folds into the assistant + bubble as ``{type: "reasoning"}`` UI parts on reload and on + ``useHydrateOnStreamEnd`` swaps. Without them the live-streamed + reasoning parts get overwritten by the hydrated (reasoning-less) + message list the moment the stream ends. Mirrors the SDK path's + ``acc.reasoning_response`` pattern so both routes render the same + way on reload. + + Pass ``session_messages`` to enable persistence; omit for pure + wire-emission (tests, scratch callers). On first reasoning delta a + fresh ``ChatMessage(role="reasoning")`` is appended and mutated + in-place as further deltas arrive; :meth:`close` drops the reference + but leaves the appended row intact. + """ + + def __init__( + self, + session_messages: list[ChatMessage] | None = None, + ) -> None: + self._block_id: str = str(uuid.uuid4()) + self._open: bool = False + self._session_messages = session_messages + self._current_row: ChatMessage | None = None + + @property + def is_open(self) -> bool: + return self._open + + def on_delta(self, delta: ChoiceDelta) -> list[StreamBaseResponse]: + """Return events for the reasoning text carried by *delta*. + + Empty list when the chunk carries no reasoning payload, so this is + safe to call on every chunk without guarding at the call site. + Persistence (when a session message list is attached) happens in + lockstep with emission so the row's content stays equal to the + concatenated deltas at every delta boundary. + """ + ext = OpenRouterDeltaExtension.from_delta(delta) + text = ext.visible_text() + if not text: + return [] + events: list[StreamBaseResponse] = [] + if not self._open: + events.append(StreamReasoningStart(id=self._block_id)) + self._open = True + if self._session_messages is not None: + self._current_row = ChatMessage(role="reasoning", content="") + self._session_messages.append(self._current_row) + events.append(StreamReasoningDelta(id=self._block_id, delta=text)) + if self._current_row is not None: + self._current_row.content = (self._current_row.content or "") + text + return events + + def close(self) -> list[StreamBaseResponse]: + """Emit ``StreamReasoningEnd`` for the open block (if any) and rotate. + + Idempotent — returns ``[]`` when no block is open. The id rotation + guarantees the next reasoning block starts with a fresh id rather + than reusing one already closed on the wire. The persisted row is + not removed — it stays in ``session_messages`` as the durable + record of what was reasoned. + """ + if not self._open: + return [] + event = StreamReasoningEnd(id=self._block_id) + self._open = False + self._block_id = str(uuid.uuid4()) + self._current_row = None + return [event] diff --git a/autogpt_platform/backend/backend/copilot/baseline/reasoning_test.py b/autogpt_platform/backend/backend/copilot/baseline/reasoning_test.py new file mode 100644 index 0000000000..df64086d5f --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/baseline/reasoning_test.py @@ -0,0 +1,281 @@ +"""Tests for the baseline reasoning extension module. + +Covers the typed OpenRouter delta parser, the stateful emitter, and the +``extra_body`` builder. The emitter is tested against real +``ChoiceDelta`` pydantic instances so the ``model_extra`` plumbing the +parser relies on is exercised end-to-end. +""" + +from openai.types.chat.chat_completion_chunk import ChoiceDelta + +from backend.copilot.baseline.reasoning import ( + BaselineReasoningEmitter, + OpenRouterDeltaExtension, + ReasoningDetail, + reasoning_extra_body, +) +from backend.copilot.model import ChatMessage +from backend.copilot.response_model import ( + StreamReasoningDelta, + StreamReasoningEnd, + StreamReasoningStart, +) + + +def _delta(**extra) -> ChoiceDelta: + """Build a ChoiceDelta with the given extension fields on ``model_extra``.""" + return ChoiceDelta.model_validate({"role": "assistant", **extra}) + + +class TestReasoningDetail: + def test_visible_text_prefers_text(self): + d = ReasoningDetail(type="reasoning.text", text="hi", summary="ignored") + assert d.visible_text == "hi" + + def test_visible_text_falls_back_to_summary(self): + d = ReasoningDetail(type="reasoning.summary", summary="tldr") + assert d.visible_text == "tldr" + + def test_visible_text_empty_for_encrypted(self): + d = ReasoningDetail(type="reasoning.encrypted") + assert d.visible_text == "" + + def test_unknown_fields_are_ignored(self): + # OpenRouter may add new fields in future payloads — they shouldn't + # cause validation errors. + d = ReasoningDetail.model_validate( + {"type": "reasoning.future", "text": "x", "signature": "opaque"} + ) + assert d.text == "x" + + def test_visible_text_empty_for_unknown_type(self): + # Unknown types may carry provider metadata that must not render as + # user-visible reasoning — regardless of whether a text/summary is + # present. Only ``reasoning.text`` / ``reasoning.summary`` surface. + d = ReasoningDetail(type="reasoning.future", text="leaked metadata") + assert d.visible_text == "" + + def test_visible_text_surfaces_text_when_type_missing(self): + # Pre-``reasoning_details`` OpenRouter payloads omit ``type`` — treat + # them as text so we don't regress the legacy structured shape. + d = ReasoningDetail(text="plain") + assert d.visible_text == "plain" + + +class TestOpenRouterDeltaExtension: + def test_from_delta_reads_model_extra(self): + delta = _delta(reasoning="step one") + ext = OpenRouterDeltaExtension.from_delta(delta) + assert ext.reasoning == "step one" + + def test_visible_text_legacy_string(self): + ext = OpenRouterDeltaExtension(reasoning="plain text") + assert ext.visible_text() == "plain text" + + def test_visible_text_deepseek_alias(self): + ext = OpenRouterDeltaExtension(reasoning_content="alt channel") + assert ext.visible_text() == "alt channel" + + def test_visible_text_structured_details_concat(self): + ext = OpenRouterDeltaExtension( + reasoning_details=[ + ReasoningDetail(type="reasoning.text", text="hello "), + ReasoningDetail(type="reasoning.text", text="world"), + ] + ) + assert ext.visible_text() == "hello world" + + def test_visible_text_skips_encrypted(self): + ext = OpenRouterDeltaExtension( + reasoning_details=[ + ReasoningDetail(type="reasoning.encrypted"), + ReasoningDetail(type="reasoning.text", text="visible"), + ] + ) + assert ext.visible_text() == "visible" + + def test_visible_text_empty_when_all_channels_blank(self): + ext = OpenRouterDeltaExtension() + assert ext.visible_text() == "" + + def test_empty_delta_produces_empty_extension(self): + ext = OpenRouterDeltaExtension.from_delta(_delta()) + assert ext.reasoning is None + assert ext.reasoning_content is None + assert ext.reasoning_details == [] + + def test_malformed_reasoning_payload_logged_and_swallowed(self, caplog): + # A malformed payload (e.g. reasoning_details shipped as a string + # rather than a list) must not abort the stream — log it and + # return an empty extension so valid text/tool events keep flowing. + # A plain mock is used here because ``from_delta`` only reads + # ``delta.model_extra`` — avoids reaching into pydantic internals + # (``__pydantic_extra__``) that could be renamed across versions. + from unittest.mock import MagicMock + + delta = MagicMock(spec=ChoiceDelta) + delta.model_extra = {"reasoning_details": "not a list"} + with caplog.at_level("WARNING"): + ext = OpenRouterDeltaExtension.from_delta(delta) + assert ext.reasoning_details == [] + assert ext.visible_text() == "" + assert any("malformed" in r.message.lower() for r in caplog.records) + + def test_unknown_typed_entry_with_text_is_not_surfaced(self): + # Regression: the legacy extractor emitted any entry with a + # ``text`` or ``summary`` field. The typed parser now filters on + # the recognised types so future provider metadata can't leak + # into the reasoning collapse. + ext = OpenRouterDeltaExtension( + reasoning_details=[ + ReasoningDetail(type="reasoning.future", text="provider metadata"), + ReasoningDetail(type="reasoning.text", text="real"), + ] + ) + assert ext.visible_text() == "real" + + +class TestReasoningExtraBody: + def test_anthropic_route_returns_fragment(self): + assert reasoning_extra_body("anthropic/claude-sonnet-4-6", 4096) == { + "reasoning": {"max_tokens": 4096} + } + + def test_direct_claude_model_id_still_matches(self): + assert reasoning_extra_body("claude-3-5-sonnet-20241022", 2048) == { + "reasoning": {"max_tokens": 2048} + } + + def test_non_anthropic_route_returns_none(self): + assert reasoning_extra_body("openai/gpt-4o", 4096) is None + assert reasoning_extra_body("google/gemini-2.5-pro", 4096) is None + + def test_zero_max_tokens_kill_switch(self): + # Operator kill switch: ``max_thinking_tokens <= 0`` disables the + # ``reasoning`` extra_body fragment even on an Anthropic route. + # Lets us silence reasoning without dropping the SDK path's budget. + assert reasoning_extra_body("anthropic/claude-sonnet-4-6", 0) is None + assert reasoning_extra_body("anthropic/claude-sonnet-4-6", -1) is None + + +class TestBaselineReasoningEmitter: + def test_first_text_delta_emits_start_then_delta(self): + emitter = BaselineReasoningEmitter() + events = emitter.on_delta(_delta(reasoning="thinking")) + + assert len(events) == 2 + assert isinstance(events[0], StreamReasoningStart) + assert isinstance(events[1], StreamReasoningDelta) + assert events[0].id == events[1].id + assert events[1].delta == "thinking" + assert emitter.is_open is True + + def test_subsequent_deltas_reuse_block_id_without_new_start(self): + emitter = BaselineReasoningEmitter() + first = emitter.on_delta(_delta(reasoning="a")) + second = emitter.on_delta(_delta(reasoning="b")) + + assert any(isinstance(e, StreamReasoningStart) for e in first) + assert all(not isinstance(e, StreamReasoningStart) for e in second) + assert len(second) == 1 + assert isinstance(second[0], StreamReasoningDelta) + assert first[0].id == second[0].id + + def test_empty_delta_emits_nothing(self): + emitter = BaselineReasoningEmitter() + assert emitter.on_delta(_delta(content="hello")) == [] + assert emitter.is_open is False + + def test_close_emits_end_and_rotates_id(self): + emitter = BaselineReasoningEmitter() + # Capture the block id from the wire event rather than reaching + # into emitter internals — the id on the emitted Start/Delta is + # what the frontend actually receives. + start_events = emitter.on_delta(_delta(reasoning="x")) + first_id = start_events[0].id + + events = emitter.close() + assert len(events) == 1 + assert isinstance(events[0], StreamReasoningEnd) + assert events[0].id == first_id + assert emitter.is_open is False + # Next reasoning uses a fresh id. + new_events = emitter.on_delta(_delta(reasoning="y")) + assert isinstance(new_events[0], StreamReasoningStart) + assert new_events[0].id != first_id + + def test_close_is_idempotent(self): + emitter = BaselineReasoningEmitter() + assert emitter.close() == [] + emitter.on_delta(_delta(reasoning="x")) + assert len(emitter.close()) == 1 + assert emitter.close() == [] + + def test_structured_details_round_trip(self): + emitter = BaselineReasoningEmitter() + events = emitter.on_delta( + _delta( + reasoning_details=[ + {"type": "reasoning.text", "text": "plan: "}, + {"type": "reasoning.summary", "summary": "do the thing"}, + ] + ) + ) + deltas = [e for e in events if isinstance(e, StreamReasoningDelta)] + assert len(deltas) == 1 + assert deltas[0].delta == "plan: do the thing" + + +class TestReasoningPersistence: + """The persistence contract: without ``role="reasoning"`` rows in + session.messages, useHydrateOnStreamEnd overwrites the live-streamed + reasoning parts and the Reasoning collapse vanishes. Every delta + must be reflected in the persisted row the moment it's emitted.""" + + def test_session_row_appended_on_first_delta(self): + session: list[ChatMessage] = [] + emitter = BaselineReasoningEmitter(session) + + assert session == [] + emitter.on_delta(_delta(reasoning="hi")) + assert len(session) == 1 + assert session[0].role == "reasoning" + assert session[0].content == "hi" + + def test_subsequent_deltas_mutate_same_row(self): + session: list[ChatMessage] = [] + emitter = BaselineReasoningEmitter(session) + + emitter.on_delta(_delta(reasoning="part one ")) + emitter.on_delta(_delta(reasoning="part two")) + + assert len(session) == 1 + assert session[0].content == "part one part two" + + def test_close_keeps_row_in_session(self): + session: list[ChatMessage] = [] + emitter = BaselineReasoningEmitter(session) + + emitter.on_delta(_delta(reasoning="thought")) + emitter.close() + + assert len(session) == 1 + assert session[0].content == "thought" + + def test_second_reasoning_block_appends_new_row(self): + session: list[ChatMessage] = [] + emitter = BaselineReasoningEmitter(session) + + emitter.on_delta(_delta(reasoning="first")) + emitter.close() + emitter.on_delta(_delta(reasoning="second")) + + assert len(session) == 2 + assert [m.content for m in session] == ["first", "second"] + + def test_no_session_means_no_persistence(self): + """Emitter without attached session list emits wire events only.""" + emitter = BaselineReasoningEmitter() + events = emitter.on_delta(_delta(reasoning="pure wire")) + assert len(events) == 2 # start + delta, no crash + # Nothing else to assert — just proves None session is supported. diff --git a/autogpt_platform/backend/backend/copilot/baseline/service.py b/autogpt_platform/backend/backend/copilot/baseline/service.py index 4e495264c8..f87ec05390 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service.py @@ -27,6 +27,10 @@ from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolPara from openai.types.completion_usage import PromptTokensDetails from opentelemetry import trace as otel_trace +from backend.copilot.baseline.reasoning import ( + BaselineReasoningEmitter, + reasoning_extra_body, +) from backend.copilot.config import CopilotLlmModel, CopilotMode from backend.copilot.context import get_workspace_manager, set_execution_context from backend.copilot.graphiti.config import is_enabled_for_user @@ -336,6 +340,7 @@ class _BaselineStreamState: assistant_text: str = "" text_block_id: str = field(default_factory=lambda: str(uuid.uuid4())) text_started: bool = False + reasoning_emitter: BaselineReasoningEmitter = field(init=False) turn_prompt_tokens: int = 0 turn_completion_tokens: int = 0 turn_cache_read_tokens: int = 0 @@ -346,6 +351,10 @@ class _BaselineStreamState: # generate one warning per streaming call. cost_missing_logged: bool = False thinking_stripper: _ThinkingStripper = field(default_factory=_ThinkingStripper) + # MUTATE in place only — ``__post_init__`` hands this list reference to + # ``BaselineReasoningEmitter`` so reasoning rows can be appended as + # deltas stream in. Reassigning (``state.session_messages = [...]``) + # would silently detach the emitter from the new list. session_messages: list[ChatMessage] = field(default_factory=list) # Tracks how much of ``assistant_text`` has already been flushed to # ``session.messages`` via mid-loop pending drains, so the ``finally`` @@ -360,6 +369,14 @@ class _BaselineStreamState: # wasn't a system role, so no marking applies). cached_system_message: dict[str, Any] | None = None + def __post_init__(self) -> None: + # Wire the reasoning emitter to ``session_messages`` so it can + # append ``role="reasoning"`` rows as reasoning streams in — the + # frontend's ``convertChatSessionToUiMessages`` relies on these + # rows to render the Reasoning collapse after the AI SDK's + # stream-end hydrate swaps in the DB-backed message list. + self.reasoning_emitter = BaselineReasoningEmitter(self.session_messages) + def _is_anthropic_model(model: str) -> bool: """Return True if *model* routes to Anthropic (native or via OpenRouter). @@ -536,12 +553,18 @@ async def _baseline_llm_caller( final_messages = messages extra_headers = None typed_messages = cast(list[ChatCompletionMessageParam], final_messages) + extra_body: dict[str, Any] = dict(_OPENROUTER_INCLUDE_USAGE_COST) + reasoning_param = reasoning_extra_body( + state.model, config.claude_agent_max_thinking_tokens + ) + if reasoning_param: + extra_body.update(reasoning_param) create_kwargs: dict[str, Any] = { "model": state.model, "messages": typed_messages, "stream": True, "stream_options": {"include_usage": True}, - "extra_body": _OPENROUTER_INCLUDE_USAGE_COST, + "extra_body": extra_body, } if extra_headers: create_kwargs["extra_headers"] = extra_headers @@ -591,7 +614,14 @@ async def _baseline_llm_caller( if not delta: continue + state.pending_events.extend(state.reasoning_emitter.on_delta(delta)) + if delta.content: + # Text and reasoning must not interleave on the wire — the + # AI SDK maps distinct start/end pairs to distinct UI + # parts. Close any open reasoning block before emitting + # the first text delta of this run. + state.pending_events.extend(state.reasoning_emitter.close()) emit = state.thinking_stripper.process(delta.content) if emit: if not state.text_started: @@ -605,6 +635,10 @@ async def _baseline_llm_caller( ) if delta.tool_calls: + # Same rule as the text branch: close any open reasoning + # block before a tool_use starts so the AI SDK treats + # reasoning and tool-use as distinct parts. + state.pending_events.extend(state.reasoning_emitter.close()) for tc in delta.tool_calls: idx = tc.index if idx not in tool_calls_by_index: @@ -629,6 +663,13 @@ async def _baseline_llm_caller( except Exception: pass + finally: + # Close open blocks on both normal and exception paths so the + # frontend always sees matched start/end pairs. An exception mid + # ``async for chunk in response`` would otherwise leave reasoning + # and/or text unterminated and only ``StreamFinishStep`` emitted — + # the Reasoning / Text collapses would never finalise. + state.pending_events.extend(state.reasoning_emitter.close()) # Flush any buffered text held back by the thinking stripper. tail = state.thinking_stripper.flush() if tail: @@ -639,12 +680,10 @@ async def _baseline_llm_caller( state.pending_events.append( StreamTextDelta(id=state.text_block_id, delta=tail) ) - # Close text block if state.text_started: state.pending_events.append(StreamTextEnd(id=state.text_block_id)) state.text_started = False state.text_block_id = str(uuid.uuid4()) - finally: # Always persist partial text so the session history stays consistent, # even when the stream is interrupted by an exception. state.assistant_text += round_text @@ -1718,25 +1757,14 @@ async def stream_chat_completion_baseline( _stream_error = True error_msg = str(e) or type(e).__name__ logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True) - # Close any open text block. The llm_caller's finally block - # already appended StreamFinishStep to pending_events, so we must - # insert StreamTextEnd *before* StreamFinishStep to preserve the - # protocol ordering: - # StreamStartStep -> StreamTextStart -> ...deltas... -> + # ``_baseline_llm_caller``'s finally block closes any open + # reasoning / text blocks and appends ``StreamFinishStep`` on + # both normal and exception paths, so pending_events already has + # the correct protocol ordering: + # StreamStartStep -> StreamReasoningStart -> ...deltas... -> + # StreamReasoningEnd -> StreamTextStart -> ...deltas... -> # StreamTextEnd -> StreamFinishStep - # Appending (or yielding directly) would place it after - # StreamFinishStep, violating the protocol. - if state.text_started: - # Find the last StreamFinishStep and insert before it. - insert_pos = len(state.pending_events) - for i in range(len(state.pending_events) - 1, -1, -1): - if isinstance(state.pending_events[i], StreamFinishStep): - insert_pos = i - break - state.pending_events.insert( - insert_pos, StreamTextEnd(id=state.text_block_id) - ) - # Drain pending events in correct order + # Just drain what's buffered, then yield the error. for evt in state.pending_events: yield evt state.pending_events.clear() diff --git a/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py b/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py index 4e70767426..4092206786 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py @@ -23,6 +23,14 @@ from backend.copilot.baseline.service import ( _mark_tools_with_cache_control, ) from backend.copilot.model import ChatMessage +from backend.copilot.response_model import ( + StreamReasoningDelta, + StreamReasoningEnd, + StreamReasoningStart, + StreamTextDelta, + StreamTextEnd, + StreamTextStart, +) from backend.copilot.transcript_builder import TranscriptBuilder from backend.util.prompt import CompressResult from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult @@ -1508,3 +1516,360 @@ class TestApplyPromptCacheMarkers: # The exact same list object reaches the provider (no copy needed). call_messages = mock_client.chat.completions.create.call_args[1]["messages"] assert call_messages is messages + + +def _make_delta_chunk( + *, + content: str | None = None, + reasoning: str | None = None, + reasoning_details: list | None = None, + reasoning_content: str | None = None, + tool_calls: list | None = None, +): + """Build a streaming chunk with a configurable ``delta`` payload. + + The ``delta`` is a real ``ChoiceDelta`` pydantic instance so OpenRouter + extension fields land on ``delta.model_extra`` — which is how + :class:`OpenRouterDeltaExtension` reads them in production. Using a + raw ``MagicMock`` here would leave ``model_extra`` unset and silently + skip the reasoning parser. ``tool_calls`` (when provided) must be + ``MagicMock`` entries compatible with the service's streaming loop; + they're set on the delta via ``object.__setattr__`` because pydantic + would otherwise reject the non-schema types. + """ + from openai.types.chat.chat_completion_chunk import ChoiceDelta + + payload: dict = {"role": "assistant"} + if content is not None: + payload["content"] = content + if reasoning is not None: + payload["reasoning"] = reasoning + if reasoning_content is not None: + payload["reasoning_content"] = reasoning_content + if reasoning_details is not None: + payload["reasoning_details"] = reasoning_details + delta = ChoiceDelta.model_validate(payload) + # ChoiceDelta's tool_calls schema expects OpenAI-typed entries; bypass + # validation so tests can use MagicMocks that mimic the streaming shape. + if tool_calls is not None: + object.__setattr__(delta, "tool_calls", tool_calls) + + chunk = MagicMock() + chunk.usage = None + choice = MagicMock() + choice.delta = delta + chunk.choices = [choice] + return chunk + + +def _make_tool_call_delta(*, index: int, call_id: str, name: str, arguments: str): + """Build a ``delta.tool_calls[i]`` entry for streaming tool-use.""" + tc = MagicMock() + tc.index = index + tc.id = call_id + function = MagicMock() + function.name = name + function.arguments = arguments + tc.function = function + return tc + + +class TestBaselineReasoningStreaming: + """End-to-end reasoning event emission through ``_baseline_llm_caller``.""" + + @pytest.mark.asyncio + async def test_reasoning_then_text_emits_paired_events(self): + state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6") + + chunks = [ + _make_delta_chunk(reasoning="thinking..."), + _make_delta_chunk(reasoning=" more"), + _make_delta_chunk(content="final answer"), + ] + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(*chunks) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + types = [type(e).__name__ for e in state.pending_events] + assert "StreamReasoningStart" in types + assert "StreamReasoningDelta" in types + assert "StreamReasoningEnd" in types + + # Reasoning must close before text opens — AI SDK v5 rejects + # interleaved reasoning / text parts. + reason_end = types.index("StreamReasoningEnd") + text_start = types.index("StreamTextStart") + assert reason_end < text_start + + # All reasoning deltas share a single block id; the text block uses + # a fresh id after the reasoning-end rotation. + reasoning_ids = { + e.id + for e in state.pending_events + if isinstance( + e, (StreamReasoningStart, StreamReasoningDelta, StreamReasoningEnd) + ) + } + text_ids = { + e.id + for e in state.pending_events + if isinstance(e, (StreamTextStart, StreamTextDelta, StreamTextEnd)) + } + assert len(reasoning_ids) == 1 + assert len(text_ids) == 1 + assert reasoning_ids.isdisjoint(text_ids) + + combined = "".join( + e.delta for e in state.pending_events if isinstance(e, StreamReasoningDelta) + ) + assert combined == "thinking... more" + + @pytest.mark.asyncio + async def test_reasoning_then_tool_call_closes_reasoning_first(self): + """A tool_call arriving mid-reasoning must close the reasoning block + before the tool-use is flushed — AI SDK v5 treats reasoning and + tool-use as distinct UI parts and rejects interleaving.""" + state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6") + + chunks = [ + _make_delta_chunk(reasoning="deliberating..."), + _make_delta_chunk( + tool_calls=[ + _make_tool_call_delta( + index=0, + call_id="call_1", + name="search", + arguments='{"q":"x"}', + ) + ], + ), + ] + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock(*chunks) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + response = await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + # A reasoning-end must have been emitted — this is the tool_calls + # branch's responsibility, not the stream-end cleanup. + types = [type(e).__name__ for e in state.pending_events] + assert "StreamReasoningStart" in types + assert "StreamReasoningEnd" in types + + # The tool_call was collected — confirms the tool-use path executed + # after reasoning closed (rather than silently dropping the tool). + assert len(response.tool_calls) == 1 + assert response.tool_calls[0].name == "search" + + # No text events — this stream had no content deltas. + assert "StreamTextStart" not in types + + @pytest.mark.asyncio + async def test_reasoning_closed_on_mid_stream_exception(self): + """Regression guard: an exception during the streaming loop must + still emit ``StreamReasoningEnd`` (and ``StreamTextEnd`` when a + text block is open) before ``StreamFinishStep`` — the frontend + collapse relies on matched start/end pairs, and the outer handler + no longer patches these after-the-fact.""" + state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6") + + async def failing_stream(): + yield _make_delta_chunk(reasoning="thinking...") + raise RuntimeError("boom") + + stream = MagicMock() + stream.close = AsyncMock() + stream.__aiter__ = lambda self: failing_stream() + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=stream) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + with pytest.raises(RuntimeError): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + types = [type(e).__name__ for e in state.pending_events] + # The reasoning block was opened, the exception fired, and the + # finally block must have closed it before emitting the finish + # step. + assert "StreamReasoningStart" in types + assert "StreamReasoningEnd" in types + assert "StreamFinishStep" in types + assert types.index("StreamReasoningEnd") < types.index("StreamFinishStep") + # Emitter is reset so a retried round starts with fresh ids. + assert state.reasoning_emitter.is_open is False + + @pytest.mark.asyncio + async def test_reasoning_param_sent_on_anthropic_routes(self): + """Anthropic route gets ``reasoning.max_tokens`` on the request.""" + state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock() + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + extra_body = mock_client.chat.completions.create.call_args[1]["extra_body"] + assert "reasoning" in extra_body + assert extra_body["reasoning"]["max_tokens"] > 0 + + @pytest.mark.asyncio + async def test_reasoning_param_absent_on_non_anthropic_routes(self): + """Non-Anthropic routes (e.g. OpenAI) must not receive ``reasoning``.""" + state = _BaselineStreamState(model="openai/gpt-4o") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock() + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + extra_body = mock_client.chat.completions.create.call_args[1]["extra_body"] + assert "reasoning" not in extra_body + + @pytest.mark.asyncio + async def test_reasoning_only_stream_still_closes_block(self): + """Regression: a stream with only reasoning (no text, no tool_call) + must still emit a matching ``reasoning-end`` at stream close so the + frontend Reasoning collapse finalises. Exercised here against + ``_baseline_llm_caller`` to cover the emitter's integration with + the finally-block, not just the unit emitter in reasoning_test.py. + """ + state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock( + _make_delta_chunk(reasoning="just thinking"), + ) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + types = [type(e).__name__ for e in state.pending_events] + assert "StreamReasoningStart" in types + assert "StreamReasoningEnd" in types + # No text was produced — no text events should be emitted. + assert "StreamTextStart" not in types + assert "StreamTextDelta" not in types + + @pytest.mark.asyncio + async def test_reasoning_param_suppressed_when_thinking_tokens_zero(self): + """Operator kill switch: setting ``claude_agent_max_thinking_tokens`` + to 0 removes the ``reasoning`` fragment from ``extra_body`` even on + an Anthropic route. Restores the zero-disables behaviour the old + ``baseline_reasoning_max_tokens`` config used to provide.""" + state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock() + ) + + with ( + patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ), + patch( + "backend.copilot.baseline.service.config.claude_agent_max_thinking_tokens", + 0, + ), + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + extra_body = mock_client.chat.completions.create.call_args[1]["extra_body"] + assert "reasoning" not in extra_body + + @pytest.mark.asyncio + async def test_reasoning_persists_to_state_session_messages(self): + """Integration guard: ``_BaselineStreamState.__post_init__`` wires + the emitter to ``state.session_messages``, so reasoning deltas + flowing through ``_baseline_llm_caller`` must produce a + ``role="reasoning"`` row on the state's session list. Catches + regressions where the wiring silently breaks (e.g. a refactor + passes the wrong list reference).""" + state = _BaselineStreamState(model="anthropic/claude-sonnet-4-6") + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_stream_mock( + _make_delta_chunk(reasoning="first "), + _make_delta_chunk(reasoning="thought"), + _make_delta_chunk(content="answer"), + ) + ) + + with patch( + "backend.copilot.baseline.service._get_openai_client", + return_value=mock_client, + ): + await _baseline_llm_caller( + messages=[{"role": "user", "content": "hi"}], + tools=[], + state=state, + ) + + reasoning_rows = [m for m in state.session_messages if m.role == "reasoning"] + assert len(reasoning_rows) == 1 + assert reasoning_rows[0].content == "first thought" diff --git a/autogpt_platform/backend/backend/copilot/config.py b/autogpt_platform/backend/backend/copilot/config.py index 1080921fd8..1bb63fe1da 100644 --- a/autogpt_platform/backend/backend/copilot/config.py +++ b/autogpt_platform/backend/backend/copilot/config.py @@ -192,12 +192,18 @@ class ChatConfig(BaseSettings): ) claude_agent_max_thinking_tokens: int = Field( default=8192, - ge=1024, + ge=0, le=128000, - description="Maximum thinking/reasoning tokens per LLM call. " - "Extended thinking on Opus can generate 50k+ tokens at $75/M — " - "capping this is the single biggest cost lever. " - "8192 is sufficient for most tasks; increase for complex reasoning.", + description="Maximum thinking/reasoning tokens per LLM call. Applies " + "to both the Claude Agent SDK path (as ``max_thinking_tokens``) and " + "the baseline OpenRouter path (as ``extra_body.reasoning.max_tokens`` " + "on Anthropic routes). Extended thinking on Opus can generate 50k+ " + "tokens at $75/M — capping this is the single biggest cost lever. " + "8192 is sufficient for most tasks; increase for complex reasoning. " + "Set to 0 to disable extended thinking on both paths (kill switch): " + "baseline skips the ``reasoning`` extra_body; SDK omits the " + "``max_thinking_tokens`` kwarg so the CLI falls back to model default " + "(which, without the flag, leaves extended thinking off).", ) claude_agent_thinking_effort: Literal["low", "medium", "high", "max"] | None = ( Field( diff --git a/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py b/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py index 5b3919c2aa..d774637ed5 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/retry_scenarios_test.py @@ -1036,6 +1036,8 @@ def _make_sdk_patches( claude_agent_max_transient_retries=1, claude_agent_max_turns=1000, claude_agent_max_budget_usd=100.0, + claude_agent_max_thinking_tokens=0, + claude_agent_thinking_effort=None, claude_agent_fallback_model=None, ), ), diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py index 8fe8aa12df..325d4271ac 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service.py @@ -3076,14 +3076,19 @@ async def stream_chat_completion_sdk( "max_turns": config.claude_agent_max_turns, # max_budget_usd: per-query spend ceiling enforced by the CLI. "max_budget_usd": config.claude_agent_max_budget_usd, - # max_thinking_tokens: cap extended thinking output per LLM call. - # Thinking tokens are billed at output rate ($75/M for Opus) and - # account for ~54% of total cost. 8192 is the default. - # Intentionally sent for all models including Sonnet — the CLI - # silently ignores this field for non-Opus models (those without - # native extended thinking), so it is safe to pass unconditionally. - "max_thinking_tokens": config.claude_agent_max_thinking_tokens, } + # max_thinking_tokens: cap extended thinking output per LLM call. + # Thinking tokens are billed at output rate ($75/M for Opus) and + # account for ~54% of total cost. 8192 is the default. + # Intentionally sent for all models including Sonnet — the CLI + # silently ignores this field for non-Opus models (those without + # native extended thinking), so it is safe to pass unconditionally. + # Setting to 0 acts as the kill switch (same as baseline): omit the + # kwarg so the CLI falls back to its default (extended thinking off). + if config.claude_agent_max_thinking_tokens > 0: + sdk_options_kwargs["max_thinking_tokens"] = ( + config.claude_agent_max_thinking_tokens + ) # effort: only set for models with extended thinking (Opus). # Setting effort on Sonnet causes tag leaks. if config.claude_agent_thinking_effort: From 38c2844b83ce821bd4dbdfc765bfdf45b735fcd2 Mon Sep 17 00:00:00 2001 From: Nicholas Tindle Date: Tue, 21 Apr 2026 10:28:44 -0500 Subject: [PATCH 08/41] feat(admin): Add system diagnostics and execution management dashboard (#11235) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Changes 🏗️ This PR adds a comprehensive admin diagnostics dashboard for monitoring system health and managing running executions. https://github.com/user-attachments/assets/f7afa3ed-63d8-4b5c-85e4-8756d9e3879e #### Backend Changes: - **New data layer** (backend/data/diagnostics.py): Created a dedicated diagnostics module following the established data layer pattern - get_execution_diagnostics() - Retrieves execution metrics (running, queued, completed counts) - get_agent_diagnostics() - Fetches agent-related metrics - get_running_executions_details() - Lists all running executions with detailed info - stop_execution() and stop_executions_bulk() - Admin controls for stopping executions - **Admin API endpoints** (backend/server/v2/admin/diagnostics_admin_routes.py): - GET /admin/diagnostics/executions - Execution status metrics - GET /admin/diagnostics/agents - Agent utilization metrics - GET /admin/diagnostics/executions/running - Paginated list of running executions - POST /admin/diagnostics/executions/stop - Stop single execution - POST /admin/diagnostics/executions/stop-bulk - Stop multiple executions - All endpoints secured with admin-only access #### Frontend Changes: - **Diagnostics Dashboard** (frontend/src/app/(platform)/admin/diagnostics/page.tsx): - Real-time system metrics display (running, queued, completed executions) - RabbitMQ queue depth monitoring - Agent utilization statistics - Auto-refresh every 30 seconds - **Execution Management Table** (frontend/src/app/(platform)/admin/diagnostics/components/ExecutionsTable.tsx): - Displays running executions with: ID, Agent Name, Version, User Email/ID, Status, Start Time - Multi-select functionality with checkboxes - Individual stop buttons for each execution - "Stop Selected" and "Stop All" bulk actions - Confirmation dialogs for safety - Pagination for handling large datasets - Toast notifications for user feedback #### Security: - All admin endpoints properly secured with requires_admin_user decorator - Frontend routes protected with role-based access controls - Admin navigation link only visible to admin users ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] Verified admin-only access to diagnostics page - [x] Tested execution metrics display and auto-refresh - [x] Confirmed RabbitMQ queue depth monitoring works - [x] Tested stopping individual executions - [x] Tested bulk stop operations with multi-select - [x] Verified pagination works for large datasets - [x] Confirmed toast notifications appear for all actions #### For configuration changes: - [x] `.env.default` is updated or already compatible with my changes (no changes needed) - [x] `docker-compose.yml` is updated or already compatible with my changes (no changes needed) - [x] I have included a list of my configuration changes in the PR description (no config changes required) --- > [!NOTE] > **Medium Risk** > Adds new admin-only endpoints that can stop, requeue, and bulk-mark executions as `FAILED`, plus schedule deletion, which can directly impact production workload and data integrity if misused or buggy. > > **Overview** > Introduces a **System Diagnostics** admin feature spanning backend + frontend to monitor execution/schedule health and perform remediation actions. > > On the backend, adds a new `backend/data/diagnostics.py` data layer and `diagnostics_admin_routes.py` with admin-secured endpoints to fetch execution/agent/schedule metrics (including RabbitMQ queue depths and invalid-state detection), list problem executions/schedules, and perform bulk operations like `stop`, `requeue`, and `cleanup` (marking orphaned/stuck items as `FAILED` or deleting orphaned schedules). It also extends `get_graph_executions`/`get_graph_executions_count` with `execution_ids` filtering, pagination, started/updated time filters, and configurable ordering to support efficient bulk/admin queries. > > On the frontend, adds an admin diagnostics page with summary cards and tables for executions and schedules (tabs for orphaned/failed/long-running/stuck-queued/invalid, plus confirmation dialogs for destructive actions), wires it into admin navigation, and adds comprehensive unit tests for both the new API routes and UI behavior. > > Reviewed by [Cursor Bugbot](https://cursor.com/bugbot) for commit 15b9ed26f9c39d5d79ad74ab66245bba79df0f01. Bugbot is set up for automated code reviews on this repo. Configure [here](https://www.cursor.com/dashboard/bugbot). --------- Co-authored-by: Claude Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: Nicholas Tindle Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> --- .../admin/diagnostics_admin_routes.py | 932 ++++++++++++ .../admin/diagnostics_admin_routes_test.py | 889 ++++++++++++ .../backend/api/features/admin/model.py | 67 + .../backend/backend/api/rest_api.py | 6 + .../backend/backend/data/diagnostics.py | 1215 ++++++++++++++++ .../backend/backend/data/diagnostics_test.py | 464 ++++++ .../backend/backend/data/execution.py | 68 +- .../backend/backend/executor/utils.py | 6 +- .../admin/__tests__/layout.test.tsx | 53 + .../__tests__/DiagnosticsContent.test.tsx | 540 +++++++ .../__tests__/ExecutionsTable.test.tsx | 1258 +++++++++++++++++ .../__tests__/SchedulesTable.test.tsx | 413 ++++++ .../admin/diagnostics/__tests__/page.test.tsx | 133 ++ .../components/DiagnosticsContent.tsx | 579 ++++++++ .../components/ExecutionsTable.tsx | 1079 ++++++++++++++ .../diagnostics/components/SchedulesTable.tsx | 455 ++++++ .../components/useDiagnosticsContent.ts | 63 + .../app/(platform)/admin/diagnostics/page.tsx | 17 + .../src/app/(platform)/admin/layout.tsx | 6 + .../frontend/src/app/api/openapi.json | 1225 ++++++++++++++++ 20 files changed, 9465 insertions(+), 3 deletions(-) create mode 100644 autogpt_platform/backend/backend/api/features/admin/diagnostics_admin_routes.py create mode 100644 autogpt_platform/backend/backend/api/features/admin/diagnostics_admin_routes_test.py create mode 100644 autogpt_platform/backend/backend/data/diagnostics.py create mode 100644 autogpt_platform/backend/backend/data/diagnostics_test.py create mode 100644 autogpt_platform/frontend/src/app/(platform)/admin/__tests__/layout.test.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/__tests__/DiagnosticsContent.test.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/__tests__/ExecutionsTable.test.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/__tests__/SchedulesTable.test.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/__tests__/page.test.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/components/DiagnosticsContent.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/components/ExecutionsTable.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/components/SchedulesTable.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/components/useDiagnosticsContent.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/page.tsx diff --git a/autogpt_platform/backend/backend/api/features/admin/diagnostics_admin_routes.py b/autogpt_platform/backend/backend/api/features/admin/diagnostics_admin_routes.py new file mode 100644 index 0000000000..4cb8ff0729 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/admin/diagnostics_admin_routes.py @@ -0,0 +1,932 @@ +import asyncio +import logging +from typing import List + +from autogpt_libs.auth import requires_admin_user +from autogpt_libs.auth.models import User as AuthUser +from fastapi import APIRouter, HTTPException, Security +from prisma.enums import AgentExecutionStatus +from pydantic import BaseModel + +from backend.api.features.admin.model import ( + AgentDiagnosticsResponse, + ExecutionDiagnosticsResponse, +) +from backend.data.diagnostics import ( + FailedExecutionDetail, + OrphanedScheduleDetail, + RunningExecutionDetail, + ScheduleDetail, + ScheduleHealthMetrics, + cleanup_all_stuck_queued_executions, + cleanup_orphaned_executions_bulk, + cleanup_orphaned_schedules_bulk, + get_agent_diagnostics, + get_all_orphaned_execution_ids, + get_all_schedules_details, + get_all_stuck_queued_execution_ids, + get_execution_diagnostics, + get_failed_executions_count, + get_failed_executions_details, + get_invalid_executions_details, + get_long_running_executions_details, + get_orphaned_executions_details, + get_orphaned_schedules_details, + get_running_executions_details, + get_schedule_health_metrics, + get_stuck_queued_executions_details, + stop_all_long_running_executions, +) +from backend.data.execution import get_graph_executions +from backend.executor.utils import add_graph_execution, stop_graph_execution + +logger = logging.getLogger(__name__) + +router = APIRouter( + prefix="/admin", + tags=["diagnostics", "admin"], + dependencies=[Security(requires_admin_user)], +) + + +class RunningExecutionsListResponse(BaseModel): + """Response model for list of running executions""" + + executions: List[RunningExecutionDetail] + total: int + + +class FailedExecutionsListResponse(BaseModel): + """Response model for list of failed executions""" + + executions: List[FailedExecutionDetail] + total: int + + +class StopExecutionRequest(BaseModel): + """Request model for stopping a single execution""" + + execution_id: str + + +class StopExecutionsRequest(BaseModel): + """Request model for stopping multiple executions""" + + execution_ids: List[str] + + +class StopExecutionResponse(BaseModel): + """Response model for stop execution operations""" + + success: bool + stopped_count: int = 0 + message: str + + +class RequeueExecutionResponse(BaseModel): + """Response model for requeue execution operations""" + + success: bool + requeued_count: int = 0 + message: str + + +@router.get( + "/diagnostics/executions", + response_model=ExecutionDiagnosticsResponse, + summary="Get Execution Diagnostics", +) +async def get_execution_diagnostics_endpoint(): + """ + Get comprehensive diagnostic information about execution status. + + Returns all execution metrics including: + - Current state (running, queued) + - Orphaned executions (>24h old, likely not in executor) + - Failure metrics (1h, 24h, rate) + - Long-running detection (stuck >1h, >24h) + - Stuck queued detection + - Throughput metrics (completions/hour) + - RabbitMQ queue depths + """ + logger.info("Getting execution diagnostics") + + diagnostics = await get_execution_diagnostics() + + response = ExecutionDiagnosticsResponse( + running_executions=diagnostics.running_count, + queued_executions_db=diagnostics.queued_db_count, + queued_executions_rabbitmq=diagnostics.rabbitmq_queue_depth, + cancel_queue_depth=diagnostics.cancel_queue_depth, + orphaned_running=diagnostics.orphaned_running, + orphaned_queued=diagnostics.orphaned_queued, + failed_count_1h=diagnostics.failed_count_1h, + failed_count_24h=diagnostics.failed_count_24h, + failure_rate_24h=diagnostics.failure_rate_24h, + stuck_running_24h=diagnostics.stuck_running_24h, + stuck_running_1h=diagnostics.stuck_running_1h, + oldest_running_hours=diagnostics.oldest_running_hours, + stuck_queued_1h=diagnostics.stuck_queued_1h, + queued_never_started=diagnostics.queued_never_started, + invalid_queued_with_start=diagnostics.invalid_queued_with_start, + invalid_running_without_start=diagnostics.invalid_running_without_start, + completed_1h=diagnostics.completed_1h, + completed_24h=diagnostics.completed_24h, + throughput_per_hour=diagnostics.throughput_per_hour, + timestamp=diagnostics.timestamp, + ) + + logger.info( + f"Execution diagnostics: running={diagnostics.running_count}, " + f"queued_db={diagnostics.queued_db_count}, " + f"orphaned={diagnostics.orphaned_running + diagnostics.orphaned_queued}, " + f"failed_24h={diagnostics.failed_count_24h}" + ) + + return response + + +@router.get( + "/diagnostics/agents", + response_model=AgentDiagnosticsResponse, + summary="Get Agent Diagnostics", +) +async def get_agent_diagnostics_endpoint(): + """ + Get diagnostic information about agents. + + Returns: + - agents_with_active_executions: Number of unique agents with running/queued executions + - timestamp: Current timestamp + """ + logger.info("Getting agent diagnostics") + + diagnostics = await get_agent_diagnostics() + + response = AgentDiagnosticsResponse( + agents_with_active_executions=diagnostics.agents_with_active_executions, + timestamp=diagnostics.timestamp, + ) + + logger.info( + f"Agent diagnostics: with_active_executions={diagnostics.agents_with_active_executions}" + ) + + return response + + +@router.get( + "/diagnostics/executions/running", + response_model=RunningExecutionsListResponse, + summary="List Running Executions", +) +async def list_running_executions( + limit: int = 100, + offset: int = 0, +): + """ + Get detailed list of running and queued executions (recent, likely active). + + Args: + limit: Maximum number of executions to return (default 100) + offset: Number of executions to skip (default 0) + + Returns: + List of running executions with details + """ + logger.info(f"Listing running executions (limit={limit}, offset={offset})") + + executions = await get_running_executions_details(limit=limit, offset=offset) + + # Get total count for pagination + diagnostics = await get_execution_diagnostics() + total = diagnostics.running_count + diagnostics.queued_db_count + + return RunningExecutionsListResponse(executions=executions, total=total) + + +@router.get( + "/diagnostics/executions/orphaned", + response_model=RunningExecutionsListResponse, + summary="List Orphaned Executions", +) +async def list_orphaned_executions( + limit: int = 100, + offset: int = 0, +): + """ + Get detailed list of orphaned executions (>24h old, likely not in executor). + + Args: + limit: Maximum number of executions to return (default 100) + offset: Number of executions to skip (default 0) + + Returns: + List of orphaned executions with details + """ + logger.info(f"Listing orphaned executions (limit={limit}, offset={offset})") + + executions = await get_orphaned_executions_details(limit=limit, offset=offset) + + # Get total count for pagination + diagnostics = await get_execution_diagnostics() + total = diagnostics.orphaned_running + diagnostics.orphaned_queued + + return RunningExecutionsListResponse(executions=executions, total=total) + + +@router.get( + "/diagnostics/executions/failed", + response_model=FailedExecutionsListResponse, + summary="List Failed Executions", +) +async def list_failed_executions( + limit: int = 100, + offset: int = 0, + hours: int = 24, +): + """ + Get detailed list of failed executions. + + Args: + limit: Maximum number of executions to return (default 100) + offset: Number of executions to skip (default 0) + hours: Number of hours to look back (default 24) + + Returns: + List of failed executions with error details + """ + logger.info( + f"Listing failed executions (limit={limit}, offset={offset}, hours={hours})" + ) + + executions = await get_failed_executions_details( + limit=limit, offset=offset, hours=hours + ) + + # Get total count for pagination + # Always count actual total for given hours parameter + total = await get_failed_executions_count(hours=hours) + + return FailedExecutionsListResponse(executions=executions, total=total) + + +@router.get( + "/diagnostics/executions/long-running", + response_model=RunningExecutionsListResponse, + summary="List Long-Running Executions", +) +async def list_long_running_executions( + limit: int = 100, + offset: int = 0, +): + """ + Get detailed list of long-running executions (RUNNING status >24h). + + Args: + limit: Maximum number of executions to return (default 100) + offset: Number of executions to skip (default 0) + + Returns: + List of long-running executions with details + """ + logger.info(f"Listing long-running executions (limit={limit}, offset={offset})") + + executions = await get_long_running_executions_details(limit=limit, offset=offset) + + # Get total count for pagination + diagnostics = await get_execution_diagnostics() + total = diagnostics.stuck_running_24h + + return RunningExecutionsListResponse(executions=executions, total=total) + + +@router.get( + "/diagnostics/executions/stuck-queued", + response_model=RunningExecutionsListResponse, + summary="List Stuck Queued Executions", +) +async def list_stuck_queued_executions( + limit: int = 100, + offset: int = 0, +): + """ + Get detailed list of stuck queued executions (QUEUED >1h, never started). + + Args: + limit: Maximum number of executions to return (default 100) + offset: Number of executions to skip (default 0) + + Returns: + List of stuck queued executions with details + """ + logger.info(f"Listing stuck queued executions (limit={limit}, offset={offset})") + + executions = await get_stuck_queued_executions_details(limit=limit, offset=offset) + + # Get total count for pagination + diagnostics = await get_execution_diagnostics() + total = diagnostics.stuck_queued_1h + + return RunningExecutionsListResponse(executions=executions, total=total) + + +@router.get( + "/diagnostics/executions/invalid", + response_model=RunningExecutionsListResponse, + summary="List Invalid Executions", +) +async def list_invalid_executions( + limit: int = 100, + offset: int = 0, +): + """ + Get detailed list of executions in invalid states (READ-ONLY). + + Invalid states indicate data corruption and require manual investigation: + - QUEUED but has startedAt (impossible - can't start while queued) + - RUNNING but no startedAt (impossible - can't run without starting) + + ⚠️ NO BULK ACTIONS PROVIDED - These need case-by-case investigation. + + Each invalid execution likely has a different root cause (crashes, race conditions, + DB corruption). Investigate the execution history and logs to determine appropriate + action (manual cleanup, status fix, or leave as-is if system recovered). + + Args: + limit: Maximum number of executions to return (default 100) + offset: Number of executions to skip (default 0) + + Returns: + List of invalid state executions with details + """ + logger.info(f"Listing invalid state executions (limit={limit}, offset={offset})") + + executions = await get_invalid_executions_details(limit=limit, offset=offset) + + # Get total count for pagination + diagnostics = await get_execution_diagnostics() + total = ( + diagnostics.invalid_queued_with_start + + diagnostics.invalid_running_without_start + ) + + return RunningExecutionsListResponse(executions=executions, total=total) + + +@router.post( + "/diagnostics/executions/requeue", + response_model=RequeueExecutionResponse, + summary="Requeue Stuck Execution", +) +async def requeue_single_execution( + request: StopExecutionRequest, # Reuse same request model (has execution_id) + user: AuthUser = Security(requires_admin_user), +): + """ + Requeue a stuck QUEUED execution (admin only). + + Uses add_graph_execution with existing graph_exec_id to requeue. + + ⚠️ WARNING: Only use for stuck executions. This will re-execute and may cost credits. + + Args: + request: Contains execution_id to requeue + + Returns: + Success status and message + """ + logger.info(f"Admin {user.user_id} requeueing execution {request.execution_id}") + + # Get the execution (validation - must be QUEUED) + executions = await get_graph_executions( + graph_exec_id=request.execution_id, + statuses=[AgentExecutionStatus.QUEUED], + ) + + if not executions: + raise HTTPException( + status_code=404, + detail="Execution not found or not in QUEUED status", + ) + + execution = executions[0] + + # Use add_graph_execution in requeue mode + await add_graph_execution( + graph_id=execution.graph_id, + user_id=execution.user_id, + graph_version=execution.graph_version, + graph_exec_id=request.execution_id, # Requeue existing execution + ) + + return RequeueExecutionResponse( + success=True, + requeued_count=1, + message="Execution requeued successfully", + ) + + +@router.post( + "/diagnostics/executions/requeue-bulk", + response_model=RequeueExecutionResponse, + summary="Requeue Multiple Stuck Executions", +) +async def requeue_multiple_executions( + request: StopExecutionsRequest, # Reuse same request model (has execution_ids) + user: AuthUser = Security(requires_admin_user), +): + """ + Requeue multiple stuck QUEUED executions (admin only). + + Uses add_graph_execution with existing graph_exec_id to requeue. + + ⚠️ WARNING: Only use for stuck executions. This will re-execute and may cost credits. + + Args: + request: Contains list of execution_ids to requeue + + Returns: + Number of executions requeued and success message + """ + logger.info( + f"Admin {user.user_id} requeueing {len(request.execution_ids)} executions" + ) + + # Get executions by ID list (must be QUEUED) + executions = await get_graph_executions( + execution_ids=request.execution_ids, + statuses=[AgentExecutionStatus.QUEUED], + ) + + if not executions: + return RequeueExecutionResponse( + success=False, + requeued_count=0, + message="No QUEUED executions found to requeue", + ) + + # Requeue all executions in parallel using add_graph_execution + async def requeue_one(exec) -> bool: + try: + await add_graph_execution( + graph_id=exec.graph_id, + user_id=exec.user_id, + graph_version=exec.graph_version, + graph_exec_id=exec.id, # Requeue existing + ) + return True + except Exception as e: + logger.error(f"Failed to requeue {exec.id}: {e}") + return False + + results = await asyncio.gather( + *[requeue_one(exec) for exec in executions], return_exceptions=False + ) + + requeued_count = sum(1 for success in results if success) + + return RequeueExecutionResponse( + success=requeued_count > 0, + requeued_count=requeued_count, + message=f"Requeued {requeued_count} of {len(request.execution_ids)} executions", + ) + + +@router.post( + "/diagnostics/executions/stop", + response_model=StopExecutionResponse, + summary="Stop Single Execution", +) +async def stop_single_execution( + request: StopExecutionRequest, + user: AuthUser = Security(requires_admin_user), +): + """ + Stop a single execution (admin only). + + Uses robust stop_graph_execution which cascades to children and waits for termination. + + Args: + request: Contains execution_id to stop + + Returns: + Success status and message + """ + logger.info(f"Admin {user.user_id} stopping execution {request.execution_id}") + + # Get the execution to find its owner user_id (required by stop_graph_execution) + executions = await get_graph_executions( + graph_exec_id=request.execution_id, + ) + + if not executions: + raise HTTPException(status_code=404, detail="Execution not found") + + execution = executions[0] + + # Use robust stop_graph_execution (cascades to children, waits for termination) + await stop_graph_execution( + user_id=execution.user_id, + graph_exec_id=request.execution_id, + wait_timeout=15.0, + cascade=True, + ) + + return StopExecutionResponse( + success=True, + stopped_count=1, + message="Execution stopped successfully", + ) + + +@router.post( + "/diagnostics/executions/stop-bulk", + response_model=StopExecutionResponse, + summary="Stop Multiple Executions", +) +async def stop_multiple_executions( + request: StopExecutionsRequest, + user: AuthUser = Security(requires_admin_user), +): + """ + Stop multiple active executions (admin only). + + Uses robust stop_graph_execution which cascades to children and waits for termination. + + Args: + request: Contains list of execution_ids to stop + + Returns: + Number of executions stopped and success message + """ + + logger.info( + f"Admin {user.user_id} stopping {len(request.execution_ids)} executions" + ) + + # Get executions by ID list + executions = await get_graph_executions( + execution_ids=request.execution_ids, + ) + + if not executions: + return StopExecutionResponse( + success=False, + stopped_count=0, + message="No executions found", + ) + + # Stop all executions in parallel using robust stop_graph_execution + async def stop_one(exec) -> bool: + try: + await stop_graph_execution( + user_id=exec.user_id, + graph_exec_id=exec.id, + wait_timeout=15.0, + cascade=True, + ) + return True + except Exception as e: + logger.error(f"Failed to stop execution {exec.id}: {e}") + return False + + results = await asyncio.gather( + *[stop_one(exec) for exec in executions], return_exceptions=False + ) + + stopped_count = sum(1 for success in results if success) + + return StopExecutionResponse( + success=stopped_count > 0, + stopped_count=stopped_count, + message=f"Stopped {stopped_count} of {len(request.execution_ids)} executions", + ) + + +@router.post( + "/diagnostics/executions/cleanup-orphaned", + response_model=StopExecutionResponse, + summary="Cleanup Orphaned Executions", +) +async def cleanup_orphaned_executions( + request: StopExecutionsRequest, + user: AuthUser = Security(requires_admin_user), +): + """ + Cleanup orphaned executions by directly updating DB status (admin only). + For executions in DB but not actually running in executor (old/stale records). + + Args: + request: Contains list of execution_ids to cleanup + + Returns: + Number of executions cleaned up and success message + """ + logger.info( + f"Admin {user.user_id} cleaning up {len(request.execution_ids)} orphaned executions" + ) + + cleaned_count = await cleanup_orphaned_executions_bulk( + request.execution_ids, user.user_id + ) + + return StopExecutionResponse( + success=cleaned_count > 0, + stopped_count=cleaned_count, + message=f"Cleaned up {cleaned_count} of {len(request.execution_ids)} orphaned executions", + ) + + +# ============================================================================ +# SCHEDULE DIAGNOSTICS ENDPOINTS +# ============================================================================ + + +class SchedulesListResponse(BaseModel): + """Response model for list of schedules""" + + schedules: List[ScheduleDetail] + total: int + + +class OrphanedSchedulesListResponse(BaseModel): + """Response model for list of orphaned schedules""" + + schedules: List[OrphanedScheduleDetail] + total: int + + +class ScheduleCleanupRequest(BaseModel): + """Request model for cleaning up schedules""" + + schedule_ids: List[str] + + +class ScheduleCleanupResponse(BaseModel): + """Response model for schedule cleanup operations""" + + success: bool + deleted_count: int = 0 + message: str + + +@router.get( + "/diagnostics/schedules", + response_model=ScheduleHealthMetrics, + summary="Get Schedule Diagnostics", +) +async def get_schedule_diagnostics_endpoint(): + """ + Get comprehensive diagnostic information about schedule health. + + Returns schedule metrics including: + - Total schedules (user vs system) + - Orphaned schedules by category + - Upcoming executions + """ + logger.info("Getting schedule diagnostics") + + diagnostics = await get_schedule_health_metrics() + + logger.info( + f"Schedule diagnostics: total={diagnostics.total_schedules}, " + f"user={diagnostics.user_schedules}, " + f"orphaned={diagnostics.total_orphaned}" + ) + + return diagnostics + + +@router.get( + "/diagnostics/schedules/all", + response_model=SchedulesListResponse, + summary="List All User Schedules", +) +async def list_all_schedules( + limit: int = 100, + offset: int = 0, +): + """ + Get detailed list of all user schedules (excludes system monitoring jobs). + + Args: + limit: Maximum number of schedules to return (default 100) + offset: Number of schedules to skip (default 0) + + Returns: + List of schedules with details + """ + logger.info(f"Listing all schedules (limit={limit}, offset={offset})") + + schedules = await get_all_schedules_details(limit=limit, offset=offset) + + # Get total count + diagnostics = await get_schedule_health_metrics() + total = diagnostics.user_schedules + + return SchedulesListResponse(schedules=schedules, total=total) + + +@router.get( + "/diagnostics/schedules/orphaned", + response_model=OrphanedSchedulesListResponse, + summary="List Orphaned Schedules", +) +async def list_orphaned_schedules(): + """ + Get detailed list of orphaned schedules with orphan reasons. + + Returns: + List of orphaned schedules categorized by orphan type + """ + logger.info("Listing orphaned schedules") + + schedules = await get_orphaned_schedules_details() + + return OrphanedSchedulesListResponse(schedules=schedules, total=len(schedules)) + + +@router.post( + "/diagnostics/schedules/cleanup-orphaned", + response_model=ScheduleCleanupResponse, + summary="Cleanup Orphaned Schedules", +) +async def cleanup_orphaned_schedules( + request: ScheduleCleanupRequest, + user: AuthUser = Security(requires_admin_user), +): + """ + Cleanup orphaned schedules by deleting from scheduler (admin only). + + Args: + request: Contains list of schedule_ids to delete + + Returns: + Number of schedules deleted and success message + """ + logger.info( + f"Admin {user.user_id} cleaning up {len(request.schedule_ids)} orphaned schedules" + ) + + deleted_count = await cleanup_orphaned_schedules_bulk( + request.schedule_ids, user.user_id + ) + + return ScheduleCleanupResponse( + success=deleted_count > 0, + deleted_count=deleted_count, + message=f"Deleted {deleted_count} of {len(request.schedule_ids)} orphaned schedules", + ) + + +@router.post( + "/diagnostics/executions/stop-all-long-running", + response_model=StopExecutionResponse, + summary="Stop ALL Long-Running Executions", +) +async def stop_all_long_running_executions_endpoint( + user: AuthUser = Security(requires_admin_user), +): + """ + Stop ALL long-running executions (RUNNING >24h) by sending cancel signals (admin only). + Operates on entire dataset, not limited to pagination. + + Returns: + Number of executions stopped and success message + """ + logger.info(f"Admin {user.user_id} stopping ALL long-running executions") + + stopped_count = await stop_all_long_running_executions(user.user_id) + + return StopExecutionResponse( + success=stopped_count > 0, + stopped_count=stopped_count, + message=f"Stopped {stopped_count} long-running executions", + ) + + +@router.post( + "/diagnostics/executions/cleanup-all-orphaned", + response_model=StopExecutionResponse, + summary="Cleanup ALL Orphaned Executions", +) +async def cleanup_all_orphaned_executions( + user: AuthUser = Security(requires_admin_user), +): + """ + Cleanup ALL orphaned executions (>24h old) by directly updating DB status. + Operates on all executions, not just paginated results. + + Returns: + Number of executions cleaned up and success message + """ + logger.info(f"Admin {user.user_id} cleaning up ALL orphaned executions") + + # Fetch all orphaned execution IDs + execution_ids = await get_all_orphaned_execution_ids() + + if not execution_ids: + return StopExecutionResponse( + success=True, + stopped_count=0, + message="No orphaned executions to cleanup", + ) + + cleaned_count = await cleanup_orphaned_executions_bulk(execution_ids, user.user_id) + + return StopExecutionResponse( + success=cleaned_count > 0, + stopped_count=cleaned_count, + message=f"Cleaned up {cleaned_count} orphaned executions", + ) + + +@router.post( + "/diagnostics/executions/cleanup-all-stuck-queued", + response_model=StopExecutionResponse, + summary="Cleanup ALL Stuck Queued Executions", +) +async def cleanup_all_stuck_queued_executions_endpoint( + user: AuthUser = Security(requires_admin_user), +): + """ + Cleanup ALL stuck queued executions (QUEUED >1h) by updating DB status (admin only). + Operates on entire dataset, not limited to pagination. + + Returns: + Number of executions cleaned up and success message + """ + logger.info(f"Admin {user.user_id} cleaning up ALL stuck queued executions") + + cleaned_count = await cleanup_all_stuck_queued_executions(user.user_id) + + return StopExecutionResponse( + success=cleaned_count > 0, + stopped_count=cleaned_count, + message=f"Cleaned up {cleaned_count} stuck queued executions", + ) + + +@router.post( + "/diagnostics/executions/requeue-all-stuck", + response_model=RequeueExecutionResponse, + summary="Requeue ALL Stuck Queued Executions", +) +async def requeue_all_stuck_executions( + user: AuthUser = Security(requires_admin_user), +): + """ + Requeue ALL stuck queued executions (QUEUED >1h) by publishing to RabbitMQ. + Operates on all executions, not just paginated results. + + Uses add_graph_execution with existing graph_exec_id to requeue. + + ⚠️ WARNING: This will re-execute ALL stuck executions and may cost significant credits. + + Returns: + Number of executions requeued and success message + """ + logger.info(f"Admin {user.user_id} requeueing ALL stuck queued executions") + + # Fetch all stuck queued execution IDs + execution_ids = await get_all_stuck_queued_execution_ids() + + if not execution_ids: + return RequeueExecutionResponse( + success=True, + requeued_count=0, + message="No stuck queued executions to requeue", + ) + + # Get stuck executions by ID list (must be QUEUED) + executions = await get_graph_executions( + execution_ids=execution_ids, + statuses=[AgentExecutionStatus.QUEUED], + ) + + # Requeue all in parallel using add_graph_execution + async def requeue_one(exec) -> bool: + try: + await add_graph_execution( + graph_id=exec.graph_id, + user_id=exec.user_id, + graph_version=exec.graph_version, + graph_exec_id=exec.id, # Requeue existing + ) + return True + except Exception as e: + logger.error(f"Failed to requeue {exec.id}: {e}") + return False + + results = await asyncio.gather( + *[requeue_one(exec) for exec in executions], return_exceptions=False + ) + + requeued_count = sum(1 for success in results if success) + + return RequeueExecutionResponse( + success=requeued_count > 0, + requeued_count=requeued_count, + message=f"Requeued {requeued_count} stuck executions", + ) diff --git a/autogpt_platform/backend/backend/api/features/admin/diagnostics_admin_routes_test.py b/autogpt_platform/backend/backend/api/features/admin/diagnostics_admin_routes_test.py new file mode 100644 index 0000000000..a3783312b0 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/admin/diagnostics_admin_routes_test.py @@ -0,0 +1,889 @@ +from datetime import datetime, timezone +from unittest.mock import AsyncMock + +import fastapi +import fastapi.testclient +import pytest +import pytest_mock +from autogpt_libs.auth.jwt_utils import get_jwt_payload +from prisma.enums import AgentExecutionStatus + +import backend.api.features.admin.diagnostics_admin_routes as diagnostics_admin_routes +from backend.data.diagnostics import ( + AgentDiagnosticsSummary, + ExecutionDiagnosticsSummary, + FailedExecutionDetail, + OrphanedScheduleDetail, + RunningExecutionDetail, + ScheduleDetail, + ScheduleHealthMetrics, +) +from backend.data.execution import GraphExecutionMeta + +app = fastapi.FastAPI() +app.include_router(diagnostics_admin_routes.router) + +client = fastapi.testclient.TestClient(app) + + +@pytest.fixture(autouse=True) +def setup_app_admin_auth(mock_jwt_admin): + """Setup admin auth overrides for all tests in this module""" + app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"] + yield + app.dependency_overrides.clear() + + +def test_get_execution_diagnostics_success( + mocker: pytest_mock.MockFixture, +): + """Test fetching execution diagnostics with invalid state detection""" + mock_diagnostics = ExecutionDiagnosticsSummary( + running_count=10, + queued_db_count=5, + rabbitmq_queue_depth=3, + cancel_queue_depth=0, + orphaned_running=2, + orphaned_queued=1, + failed_count_1h=5, + failed_count_24h=20, + failure_rate_24h=0.83, + stuck_running_24h=1, + stuck_running_1h=3, + oldest_running_hours=26.5, + stuck_queued_1h=2, + queued_never_started=1, + invalid_queued_with_start=1, # New invalid state + invalid_running_without_start=1, # New invalid state + completed_1h=50, + completed_24h=1200, + throughput_per_hour=50.0, + timestamp=datetime.now(timezone.utc).isoformat(), + ) + + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics", + return_value=mock_diagnostics, + ) + + response = client.get("/admin/diagnostics/executions") + + assert response.status_code == 200 + data = response.json() + + # Verify new invalid state fields are included + assert data["invalid_queued_with_start"] == 1 + assert data["invalid_running_without_start"] == 1 + # Verify all expected fields present + assert "running_executions" in data + assert "orphaned_running" in data + assert "failed_count_24h" in data + + +def test_list_invalid_executions( + mocker: pytest_mock.MockFixture, +): + """Test listing executions in invalid states (read-only endpoint)""" + mock_invalid_executions = [ + RunningExecutionDetail( + execution_id="exec-invalid-1", + graph_id="graph-123", + graph_name="Test Graph", + graph_version=1, + user_id="user-123", + user_email="test@example.com", + status="QUEUED", + created_at=datetime.now(timezone.utc), + started_at=datetime.now( + timezone.utc + ), # QUEUED but has startedAt - INVALID! + queue_status=None, + ), + RunningExecutionDetail( + execution_id="exec-invalid-2", + graph_id="graph-456", + graph_name="Another Graph", + graph_version=2, + user_id="user-456", + user_email="user@example.com", + status="RUNNING", + created_at=datetime.now(timezone.utc), + started_at=None, # RUNNING but no startedAt - INVALID! + queue_status=None, + ), + ] + + mock_diagnostics = ExecutionDiagnosticsSummary( + running_count=10, + queued_db_count=5, + rabbitmq_queue_depth=3, + cancel_queue_depth=0, + orphaned_running=0, + orphaned_queued=0, + failed_count_1h=0, + failed_count_24h=0, + failure_rate_24h=0.0, + stuck_running_24h=0, + stuck_running_1h=0, + oldest_running_hours=None, + stuck_queued_1h=0, + queued_never_started=0, + invalid_queued_with_start=1, + invalid_running_without_start=1, + completed_1h=0, + completed_24h=0, + throughput_per_hour=0.0, + timestamp=datetime.now(timezone.utc).isoformat(), + ) + + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_invalid_executions_details", + return_value=mock_invalid_executions, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics", + return_value=mock_diagnostics, + ) + + response = client.get("/admin/diagnostics/executions/invalid?limit=100&offset=0") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 2 # Sum of both invalid state types + assert len(data["executions"]) == 2 + # Verify both types of invalid states are returned + assert data["executions"][0]["execution_id"] in [ + "exec-invalid-1", + "exec-invalid-2", + ] + assert data["executions"][1]["execution_id"] in [ + "exec-invalid-1", + "exec-invalid-2", + ] + + +def test_requeue_single_execution_with_add_graph_execution( + mocker: pytest_mock.MockFixture, + admin_user_id: str, +): + """Test requeueing uses add_graph_execution in requeue mode""" + mock_exec_meta = GraphExecutionMeta( + id="exec-stuck-123", + user_id="user-123", + graph_id="graph-456", + graph_version=1, + inputs=None, + credential_inputs=None, + nodes_input_masks=None, + preset_id=None, + status=AgentExecutionStatus.QUEUED, + started_at=datetime.now(timezone.utc), + ended_at=datetime.now(timezone.utc), + stats=None, + ) + + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=[mock_exec_meta], + ) + + mock_add_graph_execution = mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.add_graph_execution", + return_value=AsyncMock(), + ) + + response = client.post( + "/admin/diagnostics/executions/requeue", + json={"execution_id": "exec-stuck-123"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["requeued_count"] == 1 + + # Verify it used add_graph_execution in requeue mode + mock_add_graph_execution.assert_called_once() + call_kwargs = mock_add_graph_execution.call_args.kwargs + assert call_kwargs["graph_exec_id"] == "exec-stuck-123" # Requeue mode! + assert call_kwargs["graph_id"] == "graph-456" + assert call_kwargs["user_id"] == "user-123" + + +def test_stop_single_execution_with_stop_graph_execution( + mocker: pytest_mock.MockFixture, + admin_user_id: str, +): + """Test stopping uses robust stop_graph_execution""" + mock_exec_meta = GraphExecutionMeta( + id="exec-running-123", + user_id="user-789", + graph_id="graph-999", + graph_version=2, + inputs=None, + credential_inputs=None, + nodes_input_masks=None, + preset_id=None, + status=AgentExecutionStatus.RUNNING, + started_at=datetime.now(timezone.utc), + ended_at=datetime.now(timezone.utc), + stats=None, + ) + + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=[mock_exec_meta], + ) + + mock_stop_graph_execution = mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.stop_graph_execution", + return_value=AsyncMock(), + ) + + response = client.post( + "/admin/diagnostics/executions/stop", + json={"execution_id": "exec-running-123"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["stopped_count"] == 1 + + # Verify it used stop_graph_execution with cascade + mock_stop_graph_execution.assert_called_once() + call_kwargs = mock_stop_graph_execution.call_args.kwargs + assert call_kwargs["graph_exec_id"] == "exec-running-123" + assert call_kwargs["user_id"] == "user-789" + assert call_kwargs["cascade"] is True # Stops children too! + assert call_kwargs["wait_timeout"] == 15.0 + + +def test_requeue_not_queued_execution_fails( + mocker: pytest_mock.MockFixture, +): + """Test that requeue fails if execution is not in QUEUED status""" + # Mock an execution that's RUNNING (not QUEUED) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=[], # No QUEUED executions found + ) + + response = client.post( + "/admin/diagnostics/executions/requeue", + json={"execution_id": "exec-running-123"}, + ) + + assert response.status_code == 404 + assert "not found or not in QUEUED status" in response.json()["detail"] + + +def test_list_invalid_executions_no_bulk_actions( + mocker: pytest_mock.MockFixture, +): + """Verify invalid executions endpoint is read-only (no bulk actions)""" + # This is a documentation test - the endpoint exists but should not + # have corresponding cleanup/stop/requeue endpoints + + # These endpoints should NOT exist for invalid states: + invalid_bulk_endpoints = [ + "/admin/diagnostics/executions/cleanup-invalid", + "/admin/diagnostics/executions/stop-invalid", + "/admin/diagnostics/executions/requeue-invalid", + ] + + for endpoint in invalid_bulk_endpoints: + response = client.post(endpoint, json={"execution_ids": ["test"]}) + assert response.status_code == 404, f"{endpoint} should not exist (read-only)" + + +def test_execution_ids_filter_efficiency( + mocker: pytest_mock.MockFixture, +): + """Test that bulk operations use efficient execution_ids filter""" + mock_exec_metas = [ + GraphExecutionMeta( + id=f"exec-{i}", + user_id=f"user-{i}", + graph_id="graph-123", + graph_version=1, + inputs=None, + credential_inputs=None, + nodes_input_masks=None, + preset_id=None, + status=AgentExecutionStatus.QUEUED, + started_at=datetime.now(timezone.utc), + ended_at=datetime.now(timezone.utc), + stats=None, + ) + for i in range(3) + ] + + mock_get_graph_executions = mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=mock_exec_metas, + ) + + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.add_graph_execution", + return_value=AsyncMock(), + ) + + response = client.post( + "/admin/diagnostics/executions/requeue-bulk", + json={"execution_ids": ["exec-0", "exec-1", "exec-2"]}, + ) + + assert response.status_code == 200 + + # Verify it used execution_ids filter (not fetching all queued) + mock_get_graph_executions.assert_called_once() + call_kwargs = mock_get_graph_executions.call_args.kwargs + assert "execution_ids" in call_kwargs + assert call_kwargs["execution_ids"] == ["exec-0", "exec-1", "exec-2"] + assert call_kwargs["statuses"] == [AgentExecutionStatus.QUEUED] + + +# --------------------------------------------------------------------------- +# Helper: reusable mock diagnostics summary +# --------------------------------------------------------------------------- + + +def _make_mock_diagnostics(**overrides) -> ExecutionDiagnosticsSummary: + defaults = dict( + running_count=10, + queued_db_count=5, + rabbitmq_queue_depth=3, + cancel_queue_depth=0, + orphaned_running=2, + orphaned_queued=1, + failed_count_1h=5, + failed_count_24h=20, + failure_rate_24h=0.83, + stuck_running_24h=3, + stuck_running_1h=5, + oldest_running_hours=26.5, + stuck_queued_1h=2, + queued_never_started=1, + invalid_queued_with_start=1, + invalid_running_without_start=1, + completed_1h=50, + completed_24h=1200, + throughput_per_hour=50.0, + timestamp=datetime.now(timezone.utc).isoformat(), + ) + defaults.update(overrides) + return ExecutionDiagnosticsSummary(**defaults) + + +_SENTINEL = object() + + +def _make_mock_execution( + exec_id: str = "exec-1", + status: str = "RUNNING", + started_at: datetime | None | object = _SENTINEL, +) -> RunningExecutionDetail: + return RunningExecutionDetail( + execution_id=exec_id, + graph_id="graph-123", + graph_name="Test Graph", + graph_version=1, + user_id="user-123", + user_email="test@example.com", + status=status, + created_at=datetime.now(timezone.utc), + started_at=( + datetime.now(timezone.utc) if started_at is _SENTINEL else started_at + ), + queue_status=None, + ) + + +def _make_mock_failed_execution( + exec_id: str = "exec-fail-1", +) -> FailedExecutionDetail: + return FailedExecutionDetail( + execution_id=exec_id, + graph_id="graph-123", + graph_name="Test Graph", + graph_version=1, + user_id="user-123", + user_email="test@example.com", + status="FAILED", + created_at=datetime.now(timezone.utc), + started_at=datetime.now(timezone.utc), + failed_at=datetime.now(timezone.utc), + error_message="Something went wrong", + ) + + +def _make_mock_schedule_health(**overrides) -> ScheduleHealthMetrics: + defaults = dict( + total_schedules=15, + user_schedules=10, + system_schedules=5, + orphaned_deleted_graph=2, + orphaned_no_library_access=1, + orphaned_invalid_credentials=0, + orphaned_validation_failed=0, + total_orphaned=3, + schedules_next_hour=4, + schedules_next_24h=8, + total_runs_next_hour=12, + total_runs_next_24h=48, + timestamp=datetime.now(timezone.utc).isoformat(), + ) + defaults.update(overrides) + return ScheduleHealthMetrics(**defaults) + + +# --------------------------------------------------------------------------- +# GET endpoints: execution list variants +# --------------------------------------------------------------------------- + + +def test_list_running_executions(mocker: pytest_mock.MockFixture): + mock_execs = [ + _make_mock_execution("exec-run-1"), + _make_mock_execution("exec-run-2"), + ] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_running_executions_details", + return_value=mock_execs, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics", + return_value=_make_mock_diagnostics(), + ) + + response = client.get("/admin/diagnostics/executions/running?limit=50&offset=0") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 15 # running_count(10) + queued_db_count(5) + assert len(data["executions"]) == 2 + assert data["executions"][0]["execution_id"] == "exec-run-1" + + +def test_list_orphaned_executions(mocker: pytest_mock.MockFixture): + mock_execs = [_make_mock_execution("exec-orphan-1", status="RUNNING")] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_orphaned_executions_details", + return_value=mock_execs, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics", + return_value=_make_mock_diagnostics(), + ) + + response = client.get("/admin/diagnostics/executions/orphaned?limit=50&offset=0") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 3 # orphaned_running(2) + orphaned_queued(1) + assert len(data["executions"]) == 1 + + +def test_list_failed_executions(mocker: pytest_mock.MockFixture): + mock_execs = [_make_mock_failed_execution("exec-fail-1")] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_failed_executions_details", + return_value=mock_execs, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_failed_executions_count", + return_value=42, + ) + + response = client.get( + "/admin/diagnostics/executions/failed?limit=50&offset=0&hours=24" + ) + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 42 + assert len(data["executions"]) == 1 + assert data["executions"][0]["error_message"] == "Something went wrong" + + +def test_list_long_running_executions(mocker: pytest_mock.MockFixture): + mock_execs = [_make_mock_execution("exec-long-1")] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_long_running_executions_details", + return_value=mock_execs, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics", + return_value=_make_mock_diagnostics(), + ) + + response = client.get( + "/admin/diagnostics/executions/long-running?limit=50&offset=0" + ) + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 3 # stuck_running_24h + assert len(data["executions"]) == 1 + + +def test_list_stuck_queued_executions(mocker: pytest_mock.MockFixture): + mock_execs = [ + _make_mock_execution("exec-stuck-1", status="QUEUED", started_at=None) + ] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_stuck_queued_executions_details", + return_value=mock_execs, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_execution_diagnostics", + return_value=_make_mock_diagnostics(), + ) + + response = client.get( + "/admin/diagnostics/executions/stuck-queued?limit=50&offset=0" + ) + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 2 # stuck_queued_1h + assert len(data["executions"]) == 1 + + +# --------------------------------------------------------------------------- +# GET endpoints: agent + schedule diagnostics +# --------------------------------------------------------------------------- + + +def test_get_agent_diagnostics(mocker: pytest_mock.MockFixture): + mock_diag = AgentDiagnosticsSummary( + agents_with_active_executions=7, + timestamp=datetime.now(timezone.utc).isoformat(), + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_agent_diagnostics", + return_value=mock_diag, + ) + + response = client.get("/admin/diagnostics/agents") + + assert response.status_code == 200 + data = response.json() + assert data["agents_with_active_executions"] == 7 + + +def test_get_schedule_diagnostics(mocker: pytest_mock.MockFixture): + mock_metrics = _make_mock_schedule_health() + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_schedule_health_metrics", + return_value=mock_metrics, + ) + + response = client.get("/admin/diagnostics/schedules") + + assert response.status_code == 200 + data = response.json() + assert data["user_schedules"] == 10 + assert data["total_orphaned"] == 3 + assert data["total_runs_next_hour"] == 12 + + +def test_list_all_schedules(mocker: pytest_mock.MockFixture): + mock_schedules = [ + ScheduleDetail( + schedule_id="sched-1", + schedule_name="Daily Run", + graph_id="graph-1", + graph_name="My Agent", + graph_version=1, + user_id="user-1", + user_email="alice@example.com", + cron="0 9 * * *", + timezone="UTC", + next_run_time=datetime.now(timezone.utc).isoformat(), + ), + ] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_all_schedules_details", + return_value=mock_schedules, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_schedule_health_metrics", + return_value=_make_mock_schedule_health(), + ) + + response = client.get("/admin/diagnostics/schedules/all?limit=50&offset=0") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 10 + assert len(data["schedules"]) == 1 + assert data["schedules"][0]["schedule_name"] == "Daily Run" + + +def test_list_orphaned_schedules(mocker: pytest_mock.MockFixture): + mock_orphans = [ + OrphanedScheduleDetail( + schedule_id="sched-orphan-1", + schedule_name="Ghost Schedule", + graph_id="graph-deleted", + graph_version=1, + user_id="user-1", + orphan_reason="deleted_graph", + error_detail=None, + next_run_time=datetime.now(timezone.utc).isoformat(), + ), + ] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_orphaned_schedules_details", + return_value=mock_orphans, + ) + + response = client.get("/admin/diagnostics/schedules/orphaned") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 1 + assert data["schedules"][0]["orphan_reason"] == "deleted_graph" + + +# --------------------------------------------------------------------------- +# POST endpoints: bulk stop, cleanup, requeue +# --------------------------------------------------------------------------- + + +def test_stop_multiple_executions(mocker: pytest_mock.MockFixture): + mock_exec_metas = [ + GraphExecutionMeta( + id=f"exec-{i}", + user_id=f"user-{i}", + graph_id="graph-123", + graph_version=1, + inputs=None, + credential_inputs=None, + nodes_input_masks=None, + preset_id=None, + status=AgentExecutionStatus.RUNNING, + started_at=datetime.now(timezone.utc), + ended_at=None, + stats=None, + ) + for i in range(2) + ] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=mock_exec_metas, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.stop_graph_execution", + return_value=AsyncMock(), + ) + + response = client.post( + "/admin/diagnostics/executions/stop-bulk", + json={"execution_ids": ["exec-0", "exec-1"]}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["stopped_count"] == 2 + + +def test_stop_multiple_executions_none_found(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=[], + ) + + response = client.post( + "/admin/diagnostics/executions/stop-bulk", + json={"execution_ids": ["nonexistent"]}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is False + assert data["stopped_count"] == 0 + + +def test_cleanup_orphaned_executions(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_executions_bulk", + return_value=3, + ) + + response = client.post( + "/admin/diagnostics/executions/cleanup-orphaned", + json={"execution_ids": ["exec-1", "exec-2", "exec-3"]}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["stopped_count"] == 3 + + +def test_cleanup_orphaned_schedules(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_schedules_bulk", + return_value=2, + ) + + response = client.post( + "/admin/diagnostics/schedules/cleanup-orphaned", + json={"schedule_ids": ["sched-1", "sched-2"]}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["deleted_count"] == 2 + + +def test_stop_all_long_running_executions(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.stop_all_long_running_executions", + return_value=5, + ) + + response = client.post("/admin/diagnostics/executions/stop-all-long-running") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["stopped_count"] == 5 + + +def test_cleanup_all_orphaned_executions(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_all_orphaned_execution_ids", + return_value=["exec-1", "exec-2"], + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.cleanup_orphaned_executions_bulk", + return_value=2, + ) + + response = client.post("/admin/diagnostics/executions/cleanup-all-orphaned") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["stopped_count"] == 2 + + +def test_cleanup_all_orphaned_executions_none(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_all_orphaned_execution_ids", + return_value=[], + ) + + response = client.post("/admin/diagnostics/executions/cleanup-all-orphaned") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["stopped_count"] == 0 + assert "No orphaned" in data["message"] + + +def test_cleanup_all_stuck_queued_executions(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.cleanup_all_stuck_queued_executions", + return_value=4, + ) + + response = client.post("/admin/diagnostics/executions/cleanup-all-stuck-queued") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["stopped_count"] == 4 + + +def test_requeue_all_stuck_executions(mocker: pytest_mock.MockFixture): + mock_exec_metas = [ + GraphExecutionMeta( + id=f"exec-stuck-{i}", + user_id=f"user-{i}", + graph_id="graph-123", + graph_version=1, + inputs=None, + credential_inputs=None, + nodes_input_masks=None, + preset_id=None, + status=AgentExecutionStatus.QUEUED, + started_at=None, + ended_at=None, + stats=None, + ) + for i in range(3) + ] + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_all_stuck_queued_execution_ids", + return_value=["exec-stuck-0", "exec-stuck-1", "exec-stuck-2"], + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=mock_exec_metas, + ) + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.add_graph_execution", + return_value=AsyncMock(), + ) + + response = client.post("/admin/diagnostics/executions/requeue-all-stuck") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["requeued_count"] == 3 + + +def test_requeue_all_stuck_executions_none(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_all_stuck_queued_execution_ids", + return_value=[], + ) + + response = client.post("/admin/diagnostics/executions/requeue-all-stuck") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["requeued_count"] == 0 + assert "No stuck" in data["message"] + + +def test_requeue_bulk_none_found(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=[], + ) + + response = client.post( + "/admin/diagnostics/executions/requeue-bulk", + json={"execution_ids": ["nonexistent"]}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is False + assert data["requeued_count"] == 0 + + +def test_stop_single_execution_not_found(mocker: pytest_mock.MockFixture): + mocker.patch( + "backend.api.features.admin.diagnostics_admin_routes.get_graph_executions", + return_value=[], + ) + + response = client.post( + "/admin/diagnostics/executions/stop", + json={"execution_id": "nonexistent"}, + ) + + assert response.status_code == 404 + assert "not found" in response.json()["detail"] diff --git a/autogpt_platform/backend/backend/api/features/admin/model.py b/autogpt_platform/backend/backend/api/features/admin/model.py index 82f51e8e7a..c96c6d6433 100644 --- a/autogpt_platform/backend/backend/api/features/admin/model.py +++ b/autogpt_platform/backend/backend/api/features/admin/model.py @@ -14,3 +14,70 @@ class UserHistoryResponse(BaseModel): class AddUserCreditsResponse(BaseModel): new_balance: int transaction_key: str + + +class ExecutionDiagnosticsResponse(BaseModel): + """Response model for execution diagnostics""" + + # Current execution state + running_executions: int + queued_executions_db: int + queued_executions_rabbitmq: int + cancel_queue_depth: int + + # Orphaned execution detection + orphaned_running: int + orphaned_queued: int + + # Failure metrics + failed_count_1h: int + failed_count_24h: int + failure_rate_24h: float + + # Long-running detection + stuck_running_24h: int + stuck_running_1h: int + oldest_running_hours: float | None + + # Stuck queued detection + stuck_queued_1h: int + queued_never_started: int + + # Invalid state detection (data corruption - no auto-actions) + invalid_queued_with_start: int + invalid_running_without_start: int + + # Throughput metrics + completed_1h: int + completed_24h: int + throughput_per_hour: float + + timestamp: str + + +class AgentDiagnosticsResponse(BaseModel): + """Response model for agent diagnostics""" + + agents_with_active_executions: int + timestamp: str + + +class ScheduleHealthMetrics(BaseModel): + """Response model for schedule diagnostics""" + + total_schedules: int + user_schedules: int + system_schedules: int + + # Orphan detection + orphaned_deleted_graph: int + orphaned_no_library_access: int + orphaned_invalid_credentials: int + orphaned_validation_failed: int + total_orphaned: int + + # Upcoming + schedules_next_hour: int + schedules_next_24h: int + + timestamp: str diff --git a/autogpt_platform/backend/backend/api/rest_api.py b/autogpt_platform/backend/backend/api/rest_api.py index 2b2dba397e..b4fc2da4e9 100644 --- a/autogpt_platform/backend/backend/api/rest_api.py +++ b/autogpt_platform/backend/backend/api/rest_api.py @@ -17,6 +17,7 @@ from fastapi.routing import APIRoute from prisma.errors import PrismaError import backend.api.features.admin.credit_admin_routes +import backend.api.features.admin.diagnostics_admin_routes import backend.api.features.admin.execution_analytics_routes import backend.api.features.admin.platform_cost_routes import backend.api.features.admin.rate_limit_admin_routes @@ -320,6 +321,11 @@ app.include_router( tags=["v2", "admin"], prefix="/api/credits", ) +app.include_router( + backend.api.features.admin.diagnostics_admin_routes.router, + tags=["v2", "admin"], + prefix="/api", +) app.include_router( backend.api.features.admin.execution_analytics_routes.router, tags=["v2", "admin"], diff --git a/autogpt_platform/backend/backend/data/diagnostics.py b/autogpt_platform/backend/backend/data/diagnostics.py new file mode 100644 index 0000000000..933f6c2a8a --- /dev/null +++ b/autogpt_platform/backend/backend/data/diagnostics.py @@ -0,0 +1,1215 @@ +""" +Diagnostics data layer for admin operations. +Provides functions to query and manage system diagnostics including executions and agents. +""" + +import asyncio +import logging +from datetime import datetime, timedelta, timezone +from typing import List, Optional + +from croniter import croniter +from prisma.enums import AgentExecutionStatus +from prisma.models import AgentGraph, AgentGraphExecution, LibraryAgent, User +from pydantic import BaseModel + +from backend.data.db import query_raw_with_schema +from backend.data.execution import get_graph_executions, get_graph_executions_count +from backend.data.rabbitmq import SyncRabbitMQ +from backend.executor.utils import ( + GRAPH_EXECUTION_CANCEL_EXCHANGE, + GRAPH_EXECUTION_CANCEL_QUEUE_NAME, + GRAPH_EXECUTION_QUEUE_NAME, + CancelExecutionEvent, + create_execution_queue_config, +) +from backend.util.clients import get_async_execution_queue, get_scheduler_client + +logger = logging.getLogger(__name__) + + +# System job IDs (exclude from user schedule counts) +SYSTEM_JOB_IDS = { + "cleanup_expired_files", + "report_late_executions", + "report_block_error_rates", + "process_existing_batches", + "process_weekly_summary", +} + + +class RunningExecutionDetail(BaseModel): + """Details about a running execution for admin view""" + + execution_id: str + graph_id: str + graph_name: str # Will default to "Unknown" if not available + graph_version: int + user_id: str + user_email: Optional[str] + status: str + created_at: datetime # When execution was created + started_at: Optional[datetime] # When execution started running + queue_status: Optional[str] = None + + +class FailedExecutionDetail(BaseModel): + """Details about a failed execution for admin view""" + + execution_id: str + graph_id: str + graph_name: str + graph_version: int + user_id: str + user_email: Optional[str] + status: str + created_at: datetime + started_at: Optional[datetime] + failed_at: Optional[datetime] + error_message: Optional[str] + + +class ExecutionDiagnosticsSummary(BaseModel): + """Summary of execution diagnostics""" + + # Current execution state + running_count: int + queued_db_count: int + rabbitmq_queue_depth: int + cancel_queue_depth: int + + # Orphaned execution detection (old DB records not in executor) + orphaned_running: int # Running but created >24h ago (likely orphaned) + orphaned_queued: int # Queued but created >24h ago (likely orphaned) + + # Failure metrics + failed_count_1h: int + failed_count_24h: int + failure_rate_24h: float # failures per hour over last 24h + + # Long-running detection (active executions) + stuck_running_24h: int # Running for more than 24 hours + stuck_running_1h: int # Running for more than 1 hour + oldest_running_hours: Optional[float] # Age of oldest running execution + + # Stuck queued detection + stuck_queued_1h: int # Queued for more than 1 hour + queued_never_started: int # Queued but started_at is null + + # Invalid state detection (data corruption - no auto-actions) + invalid_queued_with_start: int # QUEUED but has startedAt (impossible state) + invalid_running_without_start: int # RUNNING but no startedAt (impossible state) + + # Throughput metrics + completed_1h: int + completed_24h: int + throughput_per_hour: float # completions per hour over last 24h + + timestamp: str + + +class AgentDiagnosticsSummary(BaseModel): + """Summary of agent diagnostics""" + + agents_with_active_executions: int + timestamp: str + + +class ScheduleDetail(BaseModel): + """Details about a schedule for admin view""" + + schedule_id: str + schedule_name: str + graph_id: str + graph_name: str + graph_version: int + user_id: str + user_email: Optional[str] + cron: str + timezone: str + next_run_time: str + created_at: Optional[datetime] = None # Not available from APScheduler + + +class ScheduleHealthMetrics(BaseModel): + """Summary of schedule health diagnostics""" + + total_schedules: int + user_schedules: int # Excludes system monitoring jobs + system_schedules: int + + # Orphan detection + orphaned_deleted_graph: int + orphaned_no_library_access: int + orphaned_invalid_credentials: int + orphaned_validation_failed: int + total_orphaned: int + + # Upcoming schedules (unique count) + schedules_next_hour: int + schedules_next_24h: int + + # Upcoming execution runs (total count) + total_runs_next_hour: int + total_runs_next_24h: int + + timestamp: str + + +class OrphanedScheduleDetail(BaseModel): + """Details about an orphaned schedule""" + + schedule_id: str + schedule_name: str + graph_id: str + graph_version: int + user_id: str + orphan_reason: ( + str # deleted_graph, no_library_access, invalid_credentials, validation_failed + ) + error_detail: Optional[str] + next_run_time: str + + +def _to_running_execution_detail( + exec: AgentGraphExecution, +) -> RunningExecutionDetail: + """Convert a Prisma AgentGraphExecution (with includes) to RunningExecutionDetail.""" + return RunningExecutionDetail( + execution_id=exec.id, + graph_id=exec.agentGraphId, + graph_name=( + exec.AgentGraph.name + if exec.AgentGraph and exec.AgentGraph.name + else "Unknown" + ), + graph_version=exec.agentGraphVersion, + user_id=exec.userId, + user_email=exec.User.email if exec.User else None, + status=exec.executionStatus, + created_at=exec.createdAt, + started_at=exec.startedAt, + ) + + +_EXECUTION_ADMIN_INCLUDE = { + "AgentGraph": True, + "User": True, +} + + +async def get_execution_diagnostics() -> ExecutionDiagnosticsSummary: + """ + Get comprehensive execution diagnostics including database and queue metrics. + Uses a single batched SQL query for all count metrics to minimize DB round-trips. + + Returns: + ExecutionDiagnosticsSummary with current execution state + """ + now = datetime.now(timezone.utc) + one_hour_ago = now - timedelta(hours=1) + twenty_four_hours_ago = now - timedelta(hours=24) + + # Single SQL query to get all count metrics at once + counts = await query_raw_with_schema( + """ + SELECT + COUNT(*) FILTER ( + WHERE "executionStatus" = 'RUNNING' + ) AS running_count, + COUNT(*) FILTER ( + WHERE "executionStatus" = 'QUEUED' + ) AS queued_db_count, + COUNT(*) FILTER ( + WHERE "executionStatus" = 'RUNNING' + AND "createdAt" < $1::timestamp + ) AS orphaned_running, + COUNT(*) FILTER ( + WHERE "executionStatus" = 'QUEUED' + AND "createdAt" < $1::timestamp + ) AS orphaned_queued, + COUNT(*) FILTER ( + WHERE "executionStatus" = 'FAILED' + AND "updatedAt" >= $2::timestamp + ) AS failed_count_1h, + COUNT(*) FILTER ( + WHERE "executionStatus" = 'FAILED' + AND "updatedAt" >= $1::timestamp + ) AS failed_count_24h, + COUNT(*) FILTER ( + WHERE "executionStatus" = 'RUNNING' + AND "startedAt" IS NOT NULL + AND "startedAt" < $1::timestamp + ) AS stuck_running_24h, + COUNT(*) FILTER ( + WHERE "executionStatus" = 'RUNNING' + AND "startedAt" IS NOT NULL + AND "startedAt" < $2::timestamp + ) AS stuck_running_1h, + COUNT(*) FILTER ( + WHERE "executionStatus" = 'QUEUED' + AND "createdAt" < $2::timestamp + ) AS stuck_queued_1h, + COUNT(*) FILTER ( + WHERE "executionStatus" = 'QUEUED' + AND "startedAt" IS NULL + ) AS queued_never_started, + COUNT(*) FILTER ( + WHERE "executionStatus" = 'QUEUED' + AND "startedAt" IS NOT NULL + ) AS invalid_queued_with_start, + COUNT(*) FILTER ( + WHERE "executionStatus" = 'RUNNING' + AND "startedAt" IS NULL + ) AS invalid_running_without_start, + COUNT(*) FILTER ( + WHERE "executionStatus" = 'COMPLETED' + AND "updatedAt" >= $2::timestamp + ) AS completed_1h, + COUNT(*) FILTER ( + WHERE "executionStatus" = 'COMPLETED' + AND "updatedAt" >= $1::timestamp + ) AS completed_24h + FROM {schema_prefix}"AgentGraphExecution" + WHERE "isDeleted" = false + """, + twenty_four_hours_ago, + one_hour_ago, + ) + + row = counts[0] if counts else {} + + running_count = row.get("running_count", 0) + queued_db_count = row.get("queued_db_count", 0) + orphaned_running = row.get("orphaned_running", 0) + orphaned_queued = row.get("orphaned_queued", 0) + failed_count_1h = row.get("failed_count_1h", 0) + failed_count_24h = row.get("failed_count_24h", 0) + stuck_running_24h = row.get("stuck_running_24h", 0) + stuck_running_1h = row.get("stuck_running_1h", 0) + stuck_queued_1h = row.get("stuck_queued_1h", 0) + queued_never_started = row.get("queued_never_started", 0) + invalid_queued_with_start = row.get("invalid_queued_with_start", 0) + invalid_running_without_start = row.get("invalid_running_without_start", 0) + completed_1h = row.get("completed_1h", 0) + completed_24h = row.get("completed_24h", 0) + + failure_rate_24h = failed_count_24h / 24.0 if failed_count_24h > 0 else 0.0 + throughput_per_hour = completed_24h / 24.0 if completed_24h > 0 else 0.0 + + # RabbitMQ queue depths (blocking sync calls, run in thread pool) + rabbitmq_queue_depth, cancel_queue_depth = await asyncio.gather( + asyncio.to_thread(get_rabbitmq_queue_depth), + asyncio.to_thread(get_rabbitmq_cancel_queue_depth), + ) + + # Find oldest running execution (single query) + oldest_running_list = await get_graph_executions( + statuses=[AgentExecutionStatus.RUNNING], + order_by="startedAt", + order_direction="asc", + limit=1, + ) + + oldest_running_hours = None + if oldest_running_list and oldest_running_list[0].started_at: + age_seconds = (now - oldest_running_list[0].started_at).total_seconds() + oldest_running_hours = age_seconds / 3600.0 + + return ExecutionDiagnosticsSummary( + running_count=running_count, + queued_db_count=queued_db_count, + rabbitmq_queue_depth=rabbitmq_queue_depth, + cancel_queue_depth=cancel_queue_depth, + orphaned_running=orphaned_running, + orphaned_queued=orphaned_queued, + failed_count_1h=failed_count_1h, + failed_count_24h=failed_count_24h, + failure_rate_24h=failure_rate_24h, + stuck_running_24h=stuck_running_24h, + stuck_running_1h=stuck_running_1h, + oldest_running_hours=oldest_running_hours, + stuck_queued_1h=stuck_queued_1h, + queued_never_started=queued_never_started, + invalid_queued_with_start=invalid_queued_with_start, + invalid_running_without_start=invalid_running_without_start, + completed_1h=completed_1h, + completed_24h=completed_24h, + throughput_per_hour=throughput_per_hour, + timestamp=now.isoformat(), + ) + + +async def get_agent_diagnostics() -> AgentDiagnosticsSummary: + """ + Get comprehensive agent diagnostics. + + Returns: + AgentDiagnosticsSummary with agent metrics + """ + # Single query to count distinct agents with active executions + result = await query_raw_with_schema( + """ + SELECT COUNT(DISTINCT "agentGraphId") AS active_agents + FROM {schema_prefix}"AgentGraphExecution" + WHERE "executionStatus" IN ('RUNNING', 'QUEUED') + AND "isDeleted" = false + """ + ) + + active_agents = result[0].get("active_agents", 0) if result else 0 + + return AgentDiagnosticsSummary( + agents_with_active_executions=active_agents, + timestamp=datetime.now(timezone.utc).isoformat(), + ) + + +async def get_schedule_health_metrics() -> ScheduleHealthMetrics: + """ + Get comprehensive schedule diagnostics via Scheduler service. + + Returns: + ScheduleHealthMetrics with schedule health info + """ + scheduler = get_scheduler_client() + + # Get all schedules from scheduler service + all_schedules = await scheduler.get_execution_schedules() + + # Filter user vs system schedules + user_schedules = [s for s in all_schedules if s.id not in SYSTEM_JOB_IDS] + system_schedules_count = len(all_schedules) - len(user_schedules) + + # Detect orphaned schedules + orphans = await _detect_orphaned_schedules(user_schedules) + + # Count schedules by next run time (exclude orphaned schedules) + now = datetime.now(timezone.utc) + one_hour_from_now = now + timedelta(hours=1) + twenty_four_hours_from_now = now + timedelta(hours=24) + + orphaned_ids = set() + for category_ids in orphans.values(): + orphaned_ids.update(category_ids) + + healthy_schedules = [s for s in user_schedules if s.id not in orphaned_ids] + + schedules_next_hour = sum( + 1 + for s in healthy_schedules + if s.next_run_time + and datetime.fromisoformat(s.next_run_time.replace("Z", "+00:00")) + <= one_hour_from_now + ) + + schedules_next_24h = sum( + 1 + for s in healthy_schedules + if s.next_run_time + and datetime.fromisoformat(s.next_run_time.replace("Z", "+00:00")) + <= twenty_four_hours_from_now + ) + + # Calculate total execution runs (not just unique schedules, exclude orphaned) + total_runs_next_hour = _calculate_total_runs( + healthy_schedules, now, one_hour_from_now + ) + total_runs_next_24h = _calculate_total_runs( + healthy_schedules, now, twenty_four_hours_from_now + ) + + return ScheduleHealthMetrics( + total_schedules=len(all_schedules), + user_schedules=len(user_schedules), + system_schedules=system_schedules_count, + orphaned_deleted_graph=len(orphans["deleted_graph"]), + orphaned_no_library_access=len(orphans["no_library_access"]), + orphaned_invalid_credentials=len(orphans["invalid_credentials"]), + orphaned_validation_failed=len(orphans["validation_failed"]), + total_orphaned=sum(len(v) for v in orphans.values()), + schedules_next_hour=schedules_next_hour, + schedules_next_24h=schedules_next_24h, + total_runs_next_hour=total_runs_next_hour, + total_runs_next_24h=total_runs_next_24h, + timestamp=now.isoformat(), + ) + + +def _calculate_total_runs( + schedules: list, start_time: datetime, end_time: datetime +) -> int: + """ + Calculate total number of scheduled executions in time window. + + Args: + schedules: List of GraphExecutionJobInfo with cron expressions + start_time: Start of time window + end_time: End of time window + + Returns: + Total number of execution runs across all schedules + """ + total_runs = 0 + + for schedule in schedules: + try: + # Create cron iterator + iter = croniter(schedule.cron, start_time) + + # Count occurrences in window (with safety limit) + count = 0 + max_iterations = 2000 # Safety limit (e.g., every-minute for 24h = 1440) + + while count < max_iterations: + try: + next_run = iter.get_next(datetime) + if next_run > end_time: + break + count += 1 + except Exception: + # Handle edge cases like invalid cron progression + break + + total_runs += count + + except Exception as e: + logger.warning(f"Failed to parse cron expression '{schedule.cron}': {e}") + # Skip this schedule if cron is invalid + continue + + return total_runs + + +async def _detect_orphaned_schedules(schedules: list) -> dict: + """ + Detect orphaned schedules by validating graph, library access, and credentials. + + Args: + schedules: List of GraphExecutionJobInfo from scheduler service + + Returns: + Dict categorizing orphans by type + """ + orphans = { + "deleted_graph": [], + "no_library_access": [], + "invalid_credentials": [], + "validation_failed": [], + } + + for schedule in schedules: + try: + # Check 1: Graph exists + graph = await AgentGraph.prisma().find_unique( + where={ + "graphVersionId": { + "id": schedule.graph_id, + "version": schedule.graph_version, + } + } + ) + + if not graph: + orphans["deleted_graph"].append(schedule.id) + continue + + # Check 2: User has library access (not deleted/archived) + library_agent = await LibraryAgent.prisma().find_first( + where={ + "userId": schedule.user_id, + "agentGraphId": schedule.graph_id, + "isDeleted": False, + "isArchived": False, + } + ) + + if not library_agent: + orphans["no_library_access"].append(schedule.id) + continue + + # Check 3: Credentials exist (if any) + # Note: Full credential validation would require integration_creds_manager + # For now, skip credential validation to avoid complexity + # Orphaned credentials will be caught during execution attempt + + except Exception as e: + logger.error(f"Error validating schedule {schedule.id}: {e}") + orphans["validation_failed"].append(schedule.id) + + return orphans + + +def get_rabbitmq_queue_depth() -> int: + """ + Get the number of messages in the RabbitMQ execution queue. + + Returns: + Number of messages in queue, or -1 if error + """ + try: + # Create a temporary connection to query the queue + config = create_execution_queue_config() + rabbitmq = SyncRabbitMQ(config) + rabbitmq.connect() + + try: + # Use passive queue_declare to get queue info without modifying it + if rabbitmq._channel: + method_frame = rabbitmq._channel.queue_declare( + queue=GRAPH_EXECUTION_QUEUE_NAME, passive=True + ) + else: + raise RuntimeError("RabbitMQ channel not initialized") + + return method_frame.method.message_count + finally: + # Always clean up connection, even on error + try: + rabbitmq.disconnect() + except Exception as disconnect_err: + logger.warning( + f"Failed to close RabbitMQ connection after queue depth check: {disconnect_err}" + ) + except Exception as e: + logger.error(f"Error getting RabbitMQ queue depth: {e}") + # Return -1 to indicate an error state rather than failing the entire request + return -1 + + +def get_rabbitmq_cancel_queue_depth() -> int: + """ + Get the number of messages in the RabbitMQ cancel queue. + + Returns: + Number of messages in cancel queue, or -1 if error + """ + try: + # Create a temporary connection to query the queue + config = create_execution_queue_config() + rabbitmq = SyncRabbitMQ(config) + rabbitmq.connect() + + try: + # Use passive queue_declare to get queue info without modifying it + if rabbitmq._channel: + method_frame = rabbitmq._channel.queue_declare( + queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME, passive=True + ) + else: + raise RuntimeError("RabbitMQ channel not initialized") + + return method_frame.method.message_count + finally: + # Always clean up connection, even on error + try: + rabbitmq.disconnect() + except Exception as disconnect_err: + logger.warning( + f"Failed to close RabbitMQ connection after cancel queue check: {disconnect_err}" + ) + except Exception as e: + logger.error(f"Error getting RabbitMQ cancel queue depth: {e}") + # Return -1 to indicate an error state rather than failing the entire request + return -1 + + +async def get_all_schedules_details( + limit: int = 100, offset: int = 0 +) -> List[ScheduleDetail]: + """ + Get detailed information about all user schedules via Scheduler service. + + Args: + limit: Maximum number of schedules to return + offset: Number of schedules to skip + + Returns: + List of ScheduleDetail objects + """ + scheduler = get_scheduler_client() + + # Get all schedules from scheduler + all_schedules = await scheduler.get_execution_schedules() + + # Filter to user schedules only + user_schedules = [s for s in all_schedules if s.id not in SYSTEM_JOB_IDS] + + # Apply pagination + paginated_schedules = user_schedules[offset : offset + limit] + + # Enrich with graph and user details + results = [] + for schedule in paginated_schedules: + # Get graph name + graph = await AgentGraph.prisma().find_unique( + where={ + "graphVersionId": { + "id": schedule.graph_id, + "version": schedule.graph_version, + } + }, + ) + + graph_name = graph.name if graph and graph.name else "Unknown" + + # Fetch user by schedule creator's user_id (not graph owner) + schedule_user = await User.prisma().find_unique(where={"id": schedule.user_id}) + user_email = schedule_user.email if schedule_user else None + + results.append( + ScheduleDetail( + schedule_id=schedule.id, + schedule_name=schedule.name, + graph_id=schedule.graph_id, + graph_name=graph_name, + graph_version=schedule.graph_version, + user_id=schedule.user_id, + user_email=user_email, + cron=schedule.cron, + timezone=schedule.timezone, + next_run_time=schedule.next_run_time, + ) + ) + + return results + + +async def get_orphaned_schedules_details() -> List[OrphanedScheduleDetail]: + """ + Get detailed list of orphaned schedules with orphan reasons. + + Returns: + List of OrphanedScheduleDetail objects + """ + scheduler = get_scheduler_client() + + # Get all schedules + all_schedules = await scheduler.get_execution_schedules() + user_schedules = [s for s in all_schedules if s.id not in SYSTEM_JOB_IDS] + + # Detect orphans with categorization + orphan_categories = await _detect_orphaned_schedules(user_schedules) + + # Build detailed orphan list + results = [] + for orphan_type, schedule_ids in orphan_categories.items(): + for schedule_id in schedule_ids: + # Find the schedule + schedule = next((s for s in user_schedules if s.id == schedule_id), None) + if not schedule: + continue + + results.append( + OrphanedScheduleDetail( + schedule_id=schedule.id, + schedule_name=schedule.name, + graph_id=schedule.graph_id, + graph_version=schedule.graph_version, + user_id=schedule.user_id, + orphan_reason=orphan_type, + error_detail=None, # Could add more detail in future + next_run_time=schedule.next_run_time, + ) + ) + + return results + + +async def cleanup_orphaned_schedules_bulk( + schedule_ids: List[str], admin_user_id: str +) -> int: + """ + Cleanup multiple orphaned schedules by deleting from scheduler. + + Args: + schedule_ids: List of schedule IDs to delete + admin_user_id: ID of the admin user performing the operation + + Returns: + Number of schedules successfully deleted + """ + logger.info( + f"Admin user {admin_user_id} cleaning up {len(schedule_ids)} orphaned schedules" + ) + + scheduler = get_scheduler_client() + + # Fetch all schedules once to avoid N+1 queries + all_schedules = await scheduler.get_execution_schedules() + schedule_map = {s.id: s for s in all_schedules} + + # Delete schedules in parallel + async def delete_schedule(schedule_id: str) -> bool: + schedule = schedule_map.get(schedule_id) + if not schedule: + logger.warning(f"Schedule {schedule_id} not found") + return False + + try: + await scheduler.delete_schedule( + schedule_id=schedule_id, user_id=schedule.user_id + ) + return True + except Exception as e: + logger.error(f"Failed to delete schedule {schedule_id}: {e}") + return False + + results = await asyncio.gather( + *[delete_schedule(schedule_id) for schedule_id in schedule_ids], + return_exceptions=False, + ) + + deleted_count = sum(1 for success in results if success) + + logger.info( + f"Admin {admin_user_id} deleted {deleted_count}/{len(schedule_ids)} orphaned schedules" + ) + + return deleted_count + + +async def get_running_executions_details( + limit: int = 100, offset: int = 0 +) -> List[RunningExecutionDetail]: + """ + Get detailed information about running and queued executions. + + Args: + limit: Maximum number of executions to return + offset: Number of executions to skip + + Returns: + List of RunningExecutionDetail objects + """ + executions = await AgentGraphExecution.prisma().find_many( + where={ + "executionStatus": { + "in": [AgentExecutionStatus.RUNNING, AgentExecutionStatus.QUEUED] # type: ignore + }, + "isDeleted": False, + }, + include=_EXECUTION_ADMIN_INCLUDE, + take=limit, + skip=offset, + order={"createdAt": "desc"}, + ) + + return [_to_running_execution_detail(e) for e in executions] + + +async def get_orphaned_executions_details( + limit: int = 100, offset: int = 0 +) -> List[RunningExecutionDetail]: + """ + Get detailed information about orphaned executions (>24h old, likely not in executor). + + Args: + limit: Maximum number of executions to return + offset: Number of executions to skip + + Returns: + List of orphaned RunningExecutionDetail objects + """ + cutoff = datetime.now(timezone.utc) - timedelta(hours=24) + + executions = await AgentGraphExecution.prisma().find_many( + where={ + "executionStatus": { + "in": [AgentExecutionStatus.RUNNING, AgentExecutionStatus.QUEUED] # type: ignore + }, + "createdAt": {"lt": cutoff}, + "isDeleted": False, + }, + include=_EXECUTION_ADMIN_INCLUDE, + take=limit, + skip=offset, + order={"createdAt": "asc"}, + ) + + return [_to_running_execution_detail(e) for e in executions] + + +async def get_long_running_executions_details( + limit: int = 100, offset: int = 0 +) -> List[RunningExecutionDetail]: + """ + Get detailed information about long-running executions (RUNNING status >24h). + + Args: + limit: Maximum number of executions to return + offset: Number of executions to skip + + Returns: + List of long-running RunningExecutionDetail objects + """ + cutoff = datetime.now(timezone.utc) - timedelta(hours=24) + + executions = await AgentGraphExecution.prisma().find_many( + where={ + "executionStatus": AgentExecutionStatus.RUNNING, + "startedAt": {"lt": cutoff}, + "isDeleted": False, + }, + include=_EXECUTION_ADMIN_INCLUDE, + take=limit, + skip=offset, + order={"startedAt": "asc"}, + ) + + return [_to_running_execution_detail(e) for e in executions] + + +async def get_stuck_queued_executions_details( + limit: int = 100, offset: int = 0 +) -> List[RunningExecutionDetail]: + """ + Get detailed information about stuck queued executions (QUEUED >1h, never started). + + Args: + limit: Maximum number of executions to return + offset: Number of executions to skip + + Returns: + List of stuck queued RunningExecutionDetail objects + """ + one_hour_ago = datetime.now(timezone.utc) - timedelta(hours=1) + + executions = await AgentGraphExecution.prisma().find_many( + where={ + "executionStatus": AgentExecutionStatus.QUEUED, + "createdAt": {"lt": one_hour_ago}, + "isDeleted": False, + }, + include=_EXECUTION_ADMIN_INCLUDE, + take=limit, + skip=offset, + order={"createdAt": "asc"}, + ) + + return [_to_running_execution_detail(e) for e in executions] + + +async def get_invalid_executions_details( + limit: int = 100, offset: int = 0 +) -> List[RunningExecutionDetail]: + """ + Get detailed information about executions in invalid states. + + Invalid states are data corruption issues that require manual investigation: + - QUEUED but has startedAt (impossible - can't start while queued) + - RUNNING but no startedAt (impossible - can't run without starting) + + NO bulk actions provided - these need case-by-case investigation. + + Args: + limit: Maximum number of executions to return + offset: Number of executions to skip + + Returns: + List of invalid RunningExecutionDetail objects + """ + executions = await AgentGraphExecution.prisma().find_many( + where={ + "isDeleted": False, + "OR": [ # type: ignore + { + "executionStatus": AgentExecutionStatus.QUEUED, + "startedAt": {"not": None}, # type: ignore + }, + { + "executionStatus": AgentExecutionStatus.RUNNING, + "startedAt": None, + }, + ], + }, + include=_EXECUTION_ADMIN_INCLUDE, + take=limit, + skip=offset, + order={"createdAt": "desc"}, + ) + + return [_to_running_execution_detail(e) for e in executions] + + +async def get_failed_executions_count(hours: int = 24) -> int: + """ + Get count of failed executions within the specified time window. + + Args: + hours: Number of hours to look back (default 24) + + Returns: + Count of failed executions + """ + cutoff = datetime.now(timezone.utc) - timedelta(hours=hours) + count = await get_graph_executions_count( + statuses=[AgentExecutionStatus.FAILED], + updated_time_gte=cutoff, + ) + return count + + +async def get_failed_executions_details( + limit: int = 100, offset: int = 0, hours: int = 24 +) -> List[FailedExecutionDetail]: + """ + Get detailed information about failed executions. + + Args: + limit: Maximum number of executions to return + offset: Number of executions to skip + hours: Number of hours to look back (default 24) + + Returns: + List of FailedExecutionDetail objects + """ + cutoff = datetime.now(timezone.utc) - timedelta(hours=hours) + + executions = await AgentGraphExecution.prisma().find_many( + where={ + "executionStatus": AgentExecutionStatus.FAILED, + "updatedAt": {"gte": cutoff}, + "isDeleted": False, + }, + include=_EXECUTION_ADMIN_INCLUDE, + take=limit, + skip=offset, + order={"updatedAt": "desc"}, # Most recent failures first + ) + + results = [] + for exec in executions: + # Extract error from stats JSON field + error_message = None + if exec.stats and isinstance(exec.stats, dict): + error_message = exec.stats.get("error") + + results.append( + FailedExecutionDetail( + execution_id=exec.id, + graph_id=exec.agentGraphId, + graph_name=( + exec.AgentGraph.name + if exec.AgentGraph and exec.AgentGraph.name + else "Unknown" + ), + graph_version=exec.agentGraphVersion, + user_id=exec.userId, + user_email=exec.User.email if exec.User else None, + status=exec.executionStatus, + created_at=exec.createdAt, + started_at=exec.startedAt, + failed_at=exec.updatedAt, + error_message=error_message, + ) + ) + + return results + + +async def cleanup_orphaned_execution(execution_id: str, admin_user_id: str) -> bool: + """ + Cleanup orphaned execution by directly updating DB status. + For executions that are in DB but not actually running in executor. + + Args: + execution_id: ID of the execution to cleanup + admin_user_id: ID of the admin user performing the operation + + Returns: + True if execution was cleaned up, False otherwise + """ + logger.info( + f"Admin user {admin_user_id} cleaning up orphaned execution {execution_id}" + ) + + # Update DB status directly without sending cancel signal + result = await AgentGraphExecution.prisma().update( + where={"id": execution_id}, + data={ + "executionStatus": AgentExecutionStatus.FAILED, + "updatedAt": datetime.now(timezone.utc), + }, + ) + + logger.info( + f"Admin {admin_user_id} marked orphaned execution {execution_id} as FAILED" + ) + return result is not None + + +async def stop_all_long_running_executions(admin_user_id: str) -> int: + """ + Stop ALL long-running executions (RUNNING >24h) by sending cancel signals. + + Args: + admin_user_id: ID of the admin user performing the operation + + Returns: + Number of executions for which cancel signals were sent + """ + logger.info(f"Admin user {admin_user_id} stopping ALL long-running executions") + + # Find all long-running executions (started running >24h ago) + cutoff = datetime.now(timezone.utc) - timedelta(hours=24) + executions = await get_graph_executions( + statuses=[AgentExecutionStatus.RUNNING], + started_time_lte=cutoff, + ) + + if not executions: + logger.info("No long-running executions to stop") + return 0 + + queue_client = await get_async_execution_queue() + + # Send cancel signals in parallel + async def send_cancel_signal(exec_id: str) -> bool: + try: + await queue_client.publish_message( + routing_key="", + message=CancelExecutionEvent(graph_exec_id=exec_id).model_dump_json(), + exchange=GRAPH_EXECUTION_CANCEL_EXCHANGE, + ) + return True + except Exception as e: + logger.error(f"Failed to send cancel for {exec_id}: {e}") + return False + + # Send cancel signals in parallel + await asyncio.gather( + *[send_cancel_signal(exec.id) for exec in executions], + return_exceptions=True, # Don't fail if some signals fail + ) + + # ALSO update DB status directly (don't rely on executor) + # This ensures executions are marked FAILED even if executor restarted + result = await AgentGraphExecution.prisma().update_many( + where={ + "executionStatus": AgentExecutionStatus.RUNNING, + "startedAt": {"lt": cutoff}, + "isDeleted": False, + }, + data={ + "executionStatus": AgentExecutionStatus.FAILED, + "updatedAt": datetime.now(timezone.utc), + }, + ) + + logger.info( + f"Admin {admin_user_id} stopped {result} long-running executions (sent cancel signals + updated DB)" + ) + + return result + + +async def get_all_orphaned_execution_ids() -> List[str]: + """ + Get all orphaned execution IDs (>24h old, RUNNING or QUEUED). + + Returns: + List of execution IDs that are orphaned + """ + cutoff = datetime.now(timezone.utc) - timedelta(hours=24) + + executions = await get_graph_executions( + statuses=[AgentExecutionStatus.RUNNING, AgentExecutionStatus.QUEUED], + created_time_lte=cutoff, + ) + + return [e.id for e in executions] + + +async def cleanup_orphaned_executions_bulk( + execution_ids: List[str], admin_user_id: str +) -> int: + """ + Cleanup multiple orphaned executions by directly updating DB status. + For executions in DB but not actually running in executor (old/orphaned). + + Args: + execution_ids: List of execution IDs to cleanup + admin_user_id: ID of the admin user performing the operation + + Returns: + Number of executions successfully cleaned up + """ + logger.info( + f"Admin user {admin_user_id} cleaning up {len(execution_ids)} orphaned executions" + ) + + # Update all executions in DB directly (no cancel signals) + # Only update executions still in RUNNING/QUEUED status to avoid + # overwriting a legitimately COMPLETED execution (TOCTOU guard) + result = await AgentGraphExecution.prisma().update_many( + where={ + "id": {"in": execution_ids}, + "isDeleted": False, + "executionStatus": { + "in": [AgentExecutionStatus.RUNNING, AgentExecutionStatus.QUEUED] + }, + }, + data={ + "executionStatus": AgentExecutionStatus.FAILED, + "updatedAt": datetime.now(timezone.utc), + }, + ) + + logger.info( + f"Admin {admin_user_id} marked {result} orphaned executions as FAILED in DB" + ) + + return result + + +async def get_all_stuck_queued_execution_ids() -> List[str]: + """ + Get all stuck queued execution IDs (QUEUED >1h). + + Returns: + List of execution IDs that are stuck in QUEUED status + """ + one_hour_ago = datetime.now(timezone.utc) - timedelta(hours=1) + + executions = await get_graph_executions( + statuses=[AgentExecutionStatus.QUEUED], + created_time_lte=one_hour_ago, + ) + + return [e.id for e in executions] + + +async def cleanup_all_stuck_queued_executions(admin_user_id: str) -> int: + """ + Cleanup ALL stuck queued executions (QUEUED >1h) by updating DB status. + Operates on all stuck queued executions, not just paginated results. + + Args: + admin_user_id: ID of the admin user performing the operation + + Returns: + Number of executions successfully cleaned up + """ + logger.info(f"Admin user {admin_user_id} cleaning up ALL stuck queued executions") + + # Find all stuck queued executions (>1h old) + one_hour_ago = datetime.now(timezone.utc) - timedelta(hours=1) + + result = await AgentGraphExecution.prisma().update_many( + where={ + "executionStatus": AgentExecutionStatus.QUEUED, + "createdAt": {"lt": one_hour_ago}, + "isDeleted": False, + }, + data={ + "executionStatus": AgentExecutionStatus.FAILED, + "updatedAt": datetime.now(timezone.utc), + }, + ) + + logger.info( + f"Admin {admin_user_id} marked {result} stuck queued executions as FAILED in DB" + ) + + return result diff --git a/autogpt_platform/backend/backend/data/diagnostics_test.py b/autogpt_platform/backend/backend/data/diagnostics_test.py new file mode 100644 index 0000000000..fc52070411 --- /dev/null +++ b/autogpt_platform/backend/backend/data/diagnostics_test.py @@ -0,0 +1,464 @@ +"""Unit tests for diagnostics data layer functions.""" + +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from backend.data.diagnostics import ( + _calculate_total_runs, + _detect_orphaned_schedules, + get_execution_diagnostics, + get_rabbitmq_cancel_queue_depth, + get_rabbitmq_queue_depth, +) + +# --------------------------------------------------------------------------- +# get_execution_diagnostics tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_execution_diagnostics_full(): + """Test get_execution_diagnostics aggregates all data correctly.""" + mock_row = { + "running_count": 10, + "queued_db_count": 5, + "orphaned_running": 2, + "orphaned_queued": 1, + "failed_count_1h": 3, + "failed_count_24h": 12, + "stuck_running_24h": 1, + "stuck_running_1h": 2, + "stuck_queued_1h": 4, + "queued_never_started": 3, + "invalid_queued_with_start": 1, + "invalid_running_without_start": 0, + "completed_1h": 50, + "completed_24h": 600, + } + + mock_exec = MagicMock() + mock_exec.started_at = datetime.now(timezone.utc) - timedelta(hours=48) + + with ( + patch( + "backend.data.diagnostics.query_raw_with_schema", + new_callable=AsyncMock, + return_value=[mock_row], + ), + patch( + "backend.data.diagnostics.get_rabbitmq_queue_depth", + return_value=7, + ), + patch( + "backend.data.diagnostics.get_rabbitmq_cancel_queue_depth", + return_value=2, + ), + patch( + "backend.data.diagnostics.get_graph_executions", + new_callable=AsyncMock, + return_value=[mock_exec], + ), + ): + result = await get_execution_diagnostics() + + assert result.running_count == 10 + assert result.queued_db_count == 5 + assert result.orphaned_running == 2 + assert result.orphaned_queued == 1 + assert result.failed_count_1h == 3 + assert result.failed_count_24h == 12 + assert result.failure_rate_24h == 12 / 24.0 + assert result.stuck_running_24h == 1 + assert result.stuck_running_1h == 2 + assert result.stuck_queued_1h == 4 + assert result.queued_never_started == 3 + assert result.invalid_queued_with_start == 1 + assert result.invalid_running_without_start == 0 + assert result.completed_1h == 50 + assert result.completed_24h == 600 + assert result.throughput_per_hour == 600 / 24.0 + assert result.rabbitmq_queue_depth == 7 + assert result.cancel_queue_depth == 2 + assert result.oldest_running_hours is not None + assert result.oldest_running_hours > 47.0 + + +@pytest.mark.asyncio +async def test_get_execution_diagnostics_empty_db(): + """Test get_execution_diagnostics with empty database.""" + with ( + patch( + "backend.data.diagnostics.query_raw_with_schema", + new_callable=AsyncMock, + return_value=[{}], + ), + patch( + "backend.data.diagnostics.get_rabbitmq_queue_depth", + return_value=-1, + ), + patch( + "backend.data.diagnostics.get_rabbitmq_cancel_queue_depth", + return_value=-1, + ), + patch( + "backend.data.diagnostics.get_graph_executions", + new_callable=AsyncMock, + return_value=[], + ), + ): + result = await get_execution_diagnostics() + + assert result.running_count == 0 + assert result.queued_db_count == 0 + assert result.failure_rate_24h == 0.0 + assert result.throughput_per_hour == 0.0 + assert result.oldest_running_hours is None + assert result.rabbitmq_queue_depth == -1 + assert result.cancel_queue_depth == -1 + + +@pytest.mark.asyncio +async def test_get_execution_diagnostics_no_started_at(): + """Test oldest_running_hours when oldest execution has no started_at.""" + mock_row = { + "running_count": 1, + "queued_db_count": 0, + "orphaned_running": 0, + "orphaned_queued": 0, + "failed_count_1h": 0, + "failed_count_24h": 0, + "stuck_running_24h": 0, + "stuck_running_1h": 0, + "stuck_queued_1h": 0, + "queued_never_started": 0, + "invalid_queued_with_start": 0, + "invalid_running_without_start": 1, + "completed_1h": 0, + "completed_24h": 0, + } + + mock_exec = MagicMock() + mock_exec.started_at = None + + with ( + patch( + "backend.data.diagnostics.query_raw_with_schema", + new_callable=AsyncMock, + return_value=[mock_row], + ), + patch( + "backend.data.diagnostics.get_rabbitmq_queue_depth", + return_value=0, + ), + patch( + "backend.data.diagnostics.get_rabbitmq_cancel_queue_depth", + return_value=0, + ), + patch( + "backend.data.diagnostics.get_graph_executions", + new_callable=AsyncMock, + return_value=[mock_exec], + ), + ): + result = await get_execution_diagnostics() + + assert result.oldest_running_hours is None + + +# --------------------------------------------------------------------------- +# RabbitMQ queue depth tests +# --------------------------------------------------------------------------- + + +def test_rabbitmq_queue_depth_success(): + """Test successful RabbitMQ queue depth retrieval.""" + mock_method_frame = MagicMock() + mock_method_frame.method.message_count = 42 + + mock_channel = MagicMock() + mock_channel.queue_declare.return_value = mock_method_frame + + mock_rabbitmq = MagicMock() + mock_rabbitmq._channel = mock_channel + + with ( + patch( + "backend.data.diagnostics.create_execution_queue_config", + return_value=MagicMock(), + ), + patch( + "backend.data.diagnostics.SyncRabbitMQ", + return_value=mock_rabbitmq, + ), + ): + result = get_rabbitmq_queue_depth() + + assert result == 42 + mock_rabbitmq.connect.assert_called_once() + mock_rabbitmq.disconnect.assert_called_once() + + +def test_rabbitmq_queue_depth_connection_error(): + """Test RabbitMQ queue depth returns -1 on connection error.""" + mock_rabbitmq = MagicMock() + mock_rabbitmq.connect.side_effect = Exception("Connection refused") + + with ( + patch( + "backend.data.diagnostics.create_execution_queue_config", + return_value=MagicMock(), + ), + patch( + "backend.data.diagnostics.SyncRabbitMQ", + return_value=mock_rabbitmq, + ), + ): + result = get_rabbitmq_queue_depth() + + assert result == -1 + + +def test_rabbitmq_queue_depth_no_channel(): + """Test RabbitMQ queue depth when channel is None.""" + mock_rabbitmq = MagicMock() + mock_rabbitmq._channel = None + + with ( + patch( + "backend.data.diagnostics.create_execution_queue_config", + return_value=MagicMock(), + ), + patch( + "backend.data.diagnostics.SyncRabbitMQ", + return_value=mock_rabbitmq, + ), + ): + result = get_rabbitmq_queue_depth() + + # Should return -1 because RuntimeError is caught + assert result == -1 + + +def test_rabbitmq_cancel_queue_depth_success(): + """Test successful RabbitMQ cancel queue depth retrieval.""" + mock_method_frame = MagicMock() + mock_method_frame.method.message_count = 5 + + mock_channel = MagicMock() + mock_channel.queue_declare.return_value = mock_method_frame + + mock_rabbitmq = MagicMock() + mock_rabbitmq._channel = mock_channel + + with ( + patch( + "backend.data.diagnostics.create_execution_queue_config", + return_value=MagicMock(), + ), + patch( + "backend.data.diagnostics.SyncRabbitMQ", + return_value=mock_rabbitmq, + ), + ): + result = get_rabbitmq_cancel_queue_depth() + + assert result == 5 + + +def test_rabbitmq_cancel_queue_depth_error(): + """Test RabbitMQ cancel queue depth returns -1 on error.""" + mock_rabbitmq = MagicMock() + mock_rabbitmq.connect.side_effect = Exception("Connection refused") + + with ( + patch( + "backend.data.diagnostics.create_execution_queue_config", + return_value=MagicMock(), + ), + patch( + "backend.data.diagnostics.SyncRabbitMQ", + return_value=mock_rabbitmq, + ), + ): + result = get_rabbitmq_cancel_queue_depth() + + assert result == -1 + + +def test_rabbitmq_disconnect_error_handled(): + """Test that disconnect errors are handled gracefully.""" + mock_method_frame = MagicMock() + mock_method_frame.method.message_count = 10 + + mock_channel = MagicMock() + mock_channel.queue_declare.return_value = mock_method_frame + + mock_rabbitmq = MagicMock() + mock_rabbitmq._channel = mock_channel + mock_rabbitmq.disconnect.side_effect = Exception("Disconnect failed") + + with ( + patch( + "backend.data.diagnostics.create_execution_queue_config", + return_value=MagicMock(), + ), + patch( + "backend.data.diagnostics.SyncRabbitMQ", + return_value=mock_rabbitmq, + ), + ): + # Should still return the count even if disconnect fails + result = get_rabbitmq_queue_depth() + + assert result == 10 + + +# --------------------------------------------------------------------------- +# _calculate_total_runs tests +# --------------------------------------------------------------------------- + + +def test_calculate_total_runs_basic(): + """Test calculating total runs with a simple cron (every hour).""" + now = datetime(2026, 4, 17, 0, 0, 0, tzinfo=timezone.utc) + end = now + timedelta(hours=3) + + schedule = MagicMock() + schedule.cron = "0 * * * *" # Every hour + + result = _calculate_total_runs([schedule], now, end) + assert result == 3 # 01:00, 02:00, 03:00 + + +def test_calculate_total_runs_invalid_cron(): + """Test that invalid cron expressions are skipped.""" + now = datetime(2026, 4, 17, 0, 0, 0, tzinfo=timezone.utc) + end = now + timedelta(hours=1) + + schedule = MagicMock() + schedule.cron = "invalid cron expression" + + result = _calculate_total_runs([schedule], now, end) + assert result == 0 + + +def test_calculate_total_runs_multiple_schedules(): + """Test total runs across multiple schedules.""" + now = datetime(2026, 4, 17, 0, 0, 0, tzinfo=timezone.utc) + end = now + timedelta(hours=2) + + sched1 = MagicMock() + sched1.cron = "0 * * * *" # Every hour + + sched2 = MagicMock() + sched2.cron = "*/30 * * * *" # Every 30 min + + result = _calculate_total_runs([sched1, sched2], now, end) + # sched1: 01:00, 02:00 = 2 + # sched2: 00:30, 01:00, 01:30, 02:00 = 4 + assert result == 6 + + +def test_calculate_total_runs_empty(): + """Test with no schedules.""" + now = datetime(2026, 4, 17, 0, 0, 0, tzinfo=timezone.utc) + end = now + timedelta(hours=1) + + result = _calculate_total_runs([], now, end) + assert result == 0 + + +# --------------------------------------------------------------------------- +# _detect_orphaned_schedules tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_detect_orphaned_schedules_deleted_graph(): + """Test detection of schedules with deleted graphs.""" + schedule = MagicMock() + schedule.id = "sched-1" + schedule.graph_id = "graph-deleted" + schedule.graph_version = 1 + schedule.user_id = "user-1" + + with patch("backend.data.diagnostics.AgentGraph.prisma") as mock_graph_prisma: + mock_graph_prisma.return_value.find_unique = AsyncMock(return_value=None) + + result = await _detect_orphaned_schedules([schedule]) + + assert "sched-1" in result["deleted_graph"] + assert len(result["no_library_access"]) == 0 + + +@pytest.mark.asyncio +async def test_detect_orphaned_schedules_no_library_access(): + """Test detection of schedules where user lost library access.""" + schedule = MagicMock() + schedule.id = "sched-2" + schedule.graph_id = "graph-1" + schedule.graph_version = 1 + schedule.user_id = "user-2" + + mock_graph = MagicMock() + + with ( + patch("backend.data.diagnostics.AgentGraph.prisma") as mock_graph_prisma, + patch("backend.data.diagnostics.LibraryAgent.prisma") as mock_lib_prisma, + ): + mock_graph_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph) + mock_lib_prisma.return_value.find_first = AsyncMock(return_value=None) + + result = await _detect_orphaned_schedules([schedule]) + + assert "sched-2" in result["no_library_access"] + assert len(result["deleted_graph"]) == 0 + + +@pytest.mark.asyncio +async def test_detect_orphaned_schedules_validation_error(): + """Test detection of schedules that fail validation.""" + schedule = MagicMock() + schedule.id = "sched-3" + schedule.graph_id = "graph-1" + schedule.graph_version = 1 + schedule.user_id = "user-3" + + with patch("backend.data.diagnostics.AgentGraph.prisma") as mock_graph_prisma: + mock_graph_prisma.return_value.find_unique = AsyncMock( + side_effect=Exception("DB connection error") + ) + + result = await _detect_orphaned_schedules([schedule]) + + assert "sched-3" in result["validation_failed"] + + +@pytest.mark.asyncio +async def test_detect_orphaned_schedules_healthy(): + """Test that healthy schedules are not flagged.""" + schedule = MagicMock() + schedule.id = "sched-ok" + schedule.graph_id = "graph-1" + schedule.graph_version = 1 + schedule.user_id = "user-1" + + mock_graph = MagicMock() + mock_library_agent = MagicMock() + + with ( + patch("backend.data.diagnostics.AgentGraph.prisma") as mock_graph_prisma, + patch("backend.data.diagnostics.LibraryAgent.prisma") as mock_lib_prisma, + ): + mock_graph_prisma.return_value.find_unique = AsyncMock(return_value=mock_graph) + mock_lib_prisma.return_value.find_first = AsyncMock( + return_value=mock_library_agent + ) + + result = await _detect_orphaned_schedules([schedule]) + + assert len(result["deleted_graph"]) == 0 + assert len(result["no_library_access"]) == 0 + assert len(result["validation_failed"]) == 0 diff --git a/autogpt_platform/backend/backend/data/execution.py b/autogpt_platform/backend/backend/data/execution.py index f4b341291b..4403a59080 100644 --- a/autogpt_platform/backend/backend/data/execution.py +++ b/autogpt_platform/backend/backend/data/execution.py @@ -26,6 +26,7 @@ from prisma.models import ( AgentNodeExecutionKeyValueData, ) from prisma.types import ( + AgentGraphExecutionOrderByInput, AgentGraphExecutionUpdateManyMutationInput, AgentGraphExecutionWhereInput, AgentNodeExecutionCreateInput, @@ -510,20 +511,39 @@ class NodeExecutionResult(BaseModel): async def get_graph_executions( graph_exec_id: Optional[str] = None, + execution_ids: Optional[list[str]] = None, graph_id: Optional[str] = None, graph_version: Optional[int] = None, user_id: Optional[str] = None, statuses: Optional[list[ExecutionStatus]] = None, created_time_gte: Optional[datetime] = None, created_time_lte: Optional[datetime] = None, + started_time_gte: Optional[datetime] = None, + started_time_lte: Optional[datetime] = None, limit: Optional[int] = None, + offset: Optional[int] = None, + order_by: Literal["createdAt", "startedAt", "updatedAt"] = "createdAt", + order_direction: Literal["asc", "desc"] = "desc", ) -> list[GraphExecutionMeta]: - """⚠️ **Optional `user_id` check**: MUST USE check in user-facing endpoints.""" + """ + Get graph executions with optional filters and ordering. + + ⚠️ **Optional `user_id` check**: MUST USE check in user-facing endpoints. + + Args: + graph_exec_id: Filter by single execution ID (mutually exclusive with execution_ids) + execution_ids: Filter by list of execution IDs (mutually exclusive with graph_exec_id) + order_by: Field to order by. Defaults to "createdAt" + order_direction: Sort direction. Defaults to "desc" + """ where_filter: AgentGraphExecutionWhereInput = { "isDeleted": False, } if graph_exec_id: where_filter["id"] = graph_exec_id + elif execution_ids: + where_filter["id"] = {"in": execution_ids} + if user_id: where_filter["userId"] = user_id if graph_id: @@ -535,13 +555,36 @@ async def get_graph_executions( "gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc), "lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc), } + if started_time_gte or started_time_lte: + where_filter["startedAt"] = { + "gte": started_time_gte or datetime.min.replace(tzinfo=timezone.utc), + "lte": started_time_lte or datetime.max.replace(tzinfo=timezone.utc), + } if statuses: where_filter["OR"] = [{"executionStatus": status} for status in statuses] + # Build properly typed order clause + # Prisma wants specific typed dicts for each field, so we construct them explicitly + order_clause: AgentGraphExecutionOrderByInput + match (order_by): + case "startedAt": + order_clause = { + "startedAt": order_direction, + } + case "updatedAt": + order_clause = { + "updatedAt": order_direction, + } + case _: + order_clause = { + "createdAt": order_direction, + } + executions = await AgentGraphExecution.prisma().find_many( where=where_filter, - order={"createdAt": "desc"}, + order=order_clause, take=limit, + skip=offset, ) return [GraphExecutionMeta.from_db(execution) for execution in executions] @@ -552,6 +595,10 @@ async def get_graph_executions_count( statuses: Optional[list[ExecutionStatus]] = None, created_time_gte: Optional[datetime] = None, created_time_lte: Optional[datetime] = None, + started_time_gte: Optional[datetime] = None, + started_time_lte: Optional[datetime] = None, + updated_time_gte: Optional[datetime] = None, + updated_time_lte: Optional[datetime] = None, ) -> int: """ Get count of graph executions with optional filters. @@ -562,6 +609,10 @@ async def get_graph_executions_count( statuses: Optional list of execution statuses to filter by created_time_gte: Optional minimum creation time created_time_lte: Optional maximum creation time + started_time_gte: Optional minimum start time (when execution started running) + started_time_lte: Optional maximum start time (when execution started running) + updated_time_gte: Optional minimum update time + updated_time_lte: Optional maximum update time Returns: Count of matching graph executions @@ -581,6 +632,19 @@ async def get_graph_executions_count( "gte": created_time_gte or datetime.min.replace(tzinfo=timezone.utc), "lte": created_time_lte or datetime.max.replace(tzinfo=timezone.utc), } + + if started_time_gte or started_time_lte: + where_filter["startedAt"] = { + "gte": started_time_gte or datetime.min.replace(tzinfo=timezone.utc), + "lte": started_time_lte or datetime.max.replace(tzinfo=timezone.utc), + } + + if updated_time_gte or updated_time_lte: + where_filter["updatedAt"] = { + "gte": updated_time_gte or datetime.min.replace(tzinfo=timezone.utc), + "lte": updated_time_lte or datetime.max.replace(tzinfo=timezone.utc), + } + if statuses: where_filter["OR"] = [{"executionStatus": status} for status in statuses] diff --git a/autogpt_platform/backend/backend/executor/utils.py b/autogpt_platform/backend/backend/executor/utils.py index 8774ff03ef..24da0b3c7b 100644 --- a/autogpt_platform/backend/backend/executor/utils.py +++ b/autogpt_platform/backend/backend/executor/utils.py @@ -919,6 +919,10 @@ async def add_graph_execution( """ Adds a graph execution to the queue and returns the execution entry. + Supports two modes: + 1. CREATE mode (graph_exec_id=None): Validates, creates new DB entry, and queues + 2. REQUEUE mode (graph_exec_id provided): Fetches existing execution and re-queues it + Args: graph_id: The ID of the graph to execute. user_id: The ID of the user executing the graph. @@ -931,7 +935,7 @@ async def add_graph_execution( parent_graph_exec_id: The ID of the parent graph execution (for nested executions). graph_exec_id: If provided, resume this existing execution instead of creating a new one. Returns: - GraphExecutionEntry: The entry for the graph execution. + GraphExecutionWithNodes: The execution entry. Raises: ValueError: If the graph is not found or if there are validation errors. NotFoundError: If graph_exec_id is provided but execution is not found. diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/__tests__/layout.test.tsx b/autogpt_platform/frontend/src/app/(platform)/admin/__tests__/layout.test.tsx new file mode 100644 index 0000000000..d0ea04602b --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/admin/__tests__/layout.test.tsx @@ -0,0 +1,53 @@ +import { render, screen } from "@/tests/integrations/test-utils"; +import { describe, expect, it, vi } from "vitest"; +import AdminLayout from "../layout"; + +vi.mock("@/components/__legacy__/Sidebar", () => ({ + Sidebar: ({ + linkGroups, + }: { + linkGroups: { links: { text: string }[] }[]; + }) => ( + + ), +})); + +describe("AdminLayout", () => { + it("renders sidebar with System Diagnostics link", () => { + render( + +
Child Content
+
, + ); + expect(screen.getByText("System Diagnostics")).toBeDefined(); + }); + + it("renders child content", () => { + render( + +
Test Child
+
, + ); + expect(screen.getByText("Test Child")).toBeDefined(); + }); + + it("renders all admin navigation links", () => { + render( + +
+ , + ); + expect(screen.getByText("Marketplace Management")).toBeDefined(); + expect(screen.getByText("User Spending")).toBeDefined(); + expect(screen.getByText("System Diagnostics")).toBeDefined(); + expect(screen.getByText("User Impersonation")).toBeDefined(); + expect(screen.getByText("Rate Limits")).toBeDefined(); + expect(screen.getByText("Platform Costs")).toBeDefined(); + expect(screen.getByText("Execution Analytics")).toBeDefined(); + expect(screen.getByText("Admin User Management")).toBeDefined(); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/__tests__/DiagnosticsContent.test.tsx b/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/__tests__/DiagnosticsContent.test.tsx new file mode 100644 index 0000000000..b4b0b843af --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/__tests__/DiagnosticsContent.test.tsx @@ -0,0 +1,540 @@ +import { + render, + screen, + cleanup, + fireEvent, +} from "@/tests/integrations/test-utils"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import { DiagnosticsContent } from "../components/DiagnosticsContent"; + +// Mock the generated API hooks directly so useDiagnosticsContent code is exercised +const mockExecQuery = vi.fn(); +const mockAgentQuery = vi.fn(); +const mockScheduleQuery = vi.fn(); + +vi.mock("@/app/api/__generated__/endpoints/admin/admin", () => ({ + useGetV2GetExecutionDiagnostics: () => mockExecQuery(), + useGetV2GetAgentDiagnostics: () => mockAgentQuery(), + useGetV2GetScheduleDiagnostics: () => mockScheduleQuery(), + useGetV2ListRunningExecutions: () => ({ + data: undefined, + isLoading: false, + error: null, + refetch: vi.fn(), + }), + useGetV2ListOrphanedExecutions: () => ({ + data: undefined, + isLoading: false, + error: null, + refetch: vi.fn(), + }), + useGetV2ListFailedExecutions: () => ({ + data: undefined, + isLoading: false, + error: null, + refetch: vi.fn(), + }), + useGetV2ListLongRunningExecutions: () => ({ + data: undefined, + isLoading: false, + error: null, + refetch: vi.fn(), + }), + useGetV2ListStuckQueuedExecutions: () => ({ + data: undefined, + isLoading: false, + error: null, + refetch: vi.fn(), + }), + useGetV2ListInvalidExecutions: () => ({ + data: undefined, + isLoading: false, + error: null, + refetch: vi.fn(), + }), + usePostV2StopSingleExecution: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), + usePostV2StopMultipleExecutions: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), + usePostV2StopAllLongRunningExecutions: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), + usePostV2CleanupOrphanedExecutions: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), + usePostV2CleanupAllOrphanedExecutions: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), + usePostV2CleanupAllStuckQueuedExecutions: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), + usePostV2RequeueStuckExecution: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), + usePostV2RequeueMultipleStuckExecutions: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), + usePostV2RequeueAllStuckQueuedExecutions: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), + useGetV2ListAllUserSchedules: () => ({ + data: undefined, + isLoading: false, + error: null, + refetch: vi.fn(), + }), + useGetV2ListOrphanedSchedules: () => ({ + data: undefined, + isLoading: false, + error: null, + refetch: vi.fn(), + }), + usePostV2CleanupOrphanedSchedules: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), +})); + +afterEach(() => { + cleanup(); + mockExecQuery.mockReset(); + mockAgentQuery.mockReset(); + mockScheduleQuery.mockReset(); +}); + +const executionData = { + running_executions: 10, + queued_executions_db: 5, + queued_executions_rabbitmq: 3, + cancel_queue_depth: 0, + orphaned_running: 2, + orphaned_queued: 1, + failed_count_1h: 5, + failed_count_24h: 20, + failure_rate_24h: 0.83, + stuck_running_24h: 3, + stuck_running_1h: 5, + oldest_running_hours: 26.5, + stuck_queued_1h: 2, + queued_never_started: 1, + invalid_queued_with_start: 1, + invalid_running_without_start: 1, + completed_1h: 50, + completed_24h: 1200, + throughput_per_hour: 50.0, + timestamp: "2026-04-17T00:00:00Z", +}; + +const agentData = { + agents_with_active_executions: 7, + timestamp: "2026-04-17T00:00:00Z", +}; + +const scheduleData = { + total_schedules: 15, + user_schedules: 10, + system_schedules: 5, + orphaned_deleted_graph: 2, + orphaned_no_library_access: 1, + orphaned_invalid_credentials: 0, + orphaned_validation_failed: 0, + total_orphaned: 3, + schedules_next_hour: 4, + schedules_next_24h: 8, + total_runs_next_hour: 12, + total_runs_next_24h: 48, + timestamp: "2026-04-17T00:00:00Z", +}; + +function setupLoadedMocks() { + mockExecQuery.mockReturnValue({ + data: { data: executionData }, + isLoading: false, + isError: false, + error: null, + refetch: vi.fn(), + }); + mockAgentQuery.mockReturnValue({ + data: { data: agentData }, + isLoading: false, + isError: false, + error: null, + refetch: vi.fn(), + }); + mockScheduleQuery.mockReturnValue({ + data: { data: scheduleData }, + isLoading: false, + isError: false, + error: null, + refetch: vi.fn(), + }); +} + +function setupLoadingMocks() { + mockExecQuery.mockReturnValue({ + data: undefined, + isLoading: true, + isError: false, + error: null, + refetch: vi.fn(), + }); + mockAgentQuery.mockReturnValue({ + data: undefined, + isLoading: true, + isError: false, + error: null, + refetch: vi.fn(), + }); + mockScheduleQuery.mockReturnValue({ + data: undefined, + isLoading: true, + isError: false, + error: null, + refetch: vi.fn(), + }); +} + +function setupErrorMocks() { + mockExecQuery.mockReturnValue({ + data: undefined, + isLoading: false, + isError: true, + error: { status: 500, message: "Server error" }, + refetch: vi.fn(), + }); + mockAgentQuery.mockReturnValue({ + data: undefined, + isLoading: false, + isError: false, + error: null, + refetch: vi.fn(), + }); + mockScheduleQuery.mockReturnValue({ + data: undefined, + isLoading: false, + isError: false, + error: null, + refetch: vi.fn(), + }); +} + +describe("DiagnosticsContent", () => { + it("shows loading state", () => { + setupLoadingMocks(); + render(); + expect(screen.getByText("Loading diagnostics...")).toBeDefined(); + }); + + it("shows error state with retry", () => { + setupErrorMocks(); + render(); + expect(screen.getByText("Try Again")).toBeDefined(); + }); + + it("renders system diagnostics heading with data", () => { + setupLoadedMocks(); + render(); + expect(screen.getByText("System Diagnostics")).toBeDefined(); + expect(screen.getByText("Refresh")).toBeDefined(); + }); + + it("renders execution queue status cards", () => { + setupLoadedMocks(); + render(); + expect(screen.getByText("Execution Queue Status")).toBeDefined(); + expect(screen.getByText("Running Executions")).toBeDefined(); + expect(screen.getByText("Queued in Database")).toBeDefined(); + expect(screen.getByText("Queued in RabbitMQ")).toBeDefined(); + }); + + it("renders throughput metrics", () => { + setupLoadedMocks(); + render(); + expect(screen.getByText("System Throughput")).toBeDefined(); + expect(screen.getByText("Completed (24h)")).toBeDefined(); + expect(screen.getByText("Throughput Rate")).toBeDefined(); + expect(screen.getByText("50.0")).toBeDefined(); + }); + + it("renders schedule summary card", () => { + setupLoadedMocks(); + render(); + expect(screen.getByText("User Schedules")).toBeDefined(); + expect(screen.getByText("Upcoming Runs (1h)")).toBeDefined(); + expect(screen.getByText("Upcoming Runs (24h)")).toBeDefined(); + }); + + it("renders alert cards for critical issues", () => { + setupLoadedMocks(); + render(); + expect(screen.getByText("Orphaned Executions")).toBeDefined(); + expect(screen.getByText("Failed Executions (24h)")).toBeDefined(); + expect(screen.getByText("Long-Running Executions")).toBeDefined(); + expect(screen.getByText("Orphaned Schedules")).toBeDefined(); + expect(screen.getByText("Invalid States (Data Corruption)")).toBeDefined(); + }); + + it("hides alert cards when counts are zero", () => { + mockExecQuery.mockReturnValue({ + data: { + data: { + ...executionData, + orphaned_running: 0, + orphaned_queued: 0, + failed_count_24h: 0, + stuck_running_24h: 0, + invalid_queued_with_start: 0, + invalid_running_without_start: 0, + }, + }, + isLoading: false, + isError: false, + error: null, + refetch: vi.fn(), + }); + mockAgentQuery.mockReturnValue({ + data: { data: agentData }, + isLoading: false, + isError: false, + error: null, + refetch: vi.fn(), + }); + mockScheduleQuery.mockReturnValue({ + data: { data: { ...scheduleData, total_orphaned: 0 } }, + isLoading: false, + isError: false, + error: null, + refetch: vi.fn(), + }); + render(); + expect(screen.queryByText("Orphaned Executions")).toBeNull(); + expect(screen.queryByText("Failed Executions (24h)")).toBeNull(); + expect(screen.queryByText("Long-Running Executions")).toBeNull(); + expect(screen.queryByText("Orphaned Schedules")).toBeNull(); + expect(screen.queryByText("Invalid States (Data Corruption)")).toBeNull(); + }); + + it("renders diagnostic information section", () => { + setupLoadedMocks(); + render(); + expect(screen.getByText("Diagnostic Information")).toBeDefined(); + expect(screen.getByText("Throughput Metrics:")).toBeDefined(); + expect(screen.getByText("Queue Health:")).toBeDefined(); + }); + + it("shows no data message when execution data is null", () => { + mockExecQuery.mockReturnValue({ + data: undefined, + isLoading: false, + isError: false, + error: null, + refetch: vi.fn(), + }); + mockAgentQuery.mockReturnValue({ + data: undefined, + isLoading: false, + isError: false, + error: null, + refetch: vi.fn(), + }); + mockScheduleQuery.mockReturnValue({ + data: undefined, + isLoading: false, + isError: false, + error: null, + refetch: vi.fn(), + }); + render(); + const noDataMessages = screen.getAllByText("No data available"); + expect(noDataMessages.length).toBeGreaterThanOrEqual(1); + }); + + it("shows RabbitMQ error state when depth is -1", () => { + mockExecQuery.mockReturnValue({ + data: { + data: { ...executionData, queued_executions_rabbitmq: -1 }, + }, + isLoading: false, + isError: false, + error: null, + refetch: vi.fn(), + }); + mockAgentQuery.mockReturnValue({ + data: { data: agentData }, + isLoading: false, + isError: false, + error: null, + refetch: vi.fn(), + }); + mockScheduleQuery.mockReturnValue({ + data: { data: scheduleData }, + isLoading: false, + isError: false, + error: null, + refetch: vi.fn(), + }); + render(); + const errorTexts = screen.getAllByText("Error"); + expect(errorTexts.length).toBeGreaterThanOrEqual(1); + }); + + it("renders completed 24h and 1h values", () => { + setupLoadedMocks(); + render(); + expect(screen.getByText("1200")).toBeDefined(); + expect(screen.getByText("50 in last hour")).toBeDefined(); + }); + + it("renders schedule metric values", () => { + setupLoadedMocks(); + render(); + expect(screen.getByText("12")).toBeDefined(); + expect(screen.getByText("48")).toBeDefined(); + }); + + it("renders oldest running hours in alert card", () => { + setupLoadedMocks(); + render(); + expect(screen.getByText(/oldest:.*26h/)).toBeDefined(); + }); + + it("renders cancel queue depth error when -1", () => { + mockExecQuery.mockReturnValue({ + data: { + data: { ...executionData, cancel_queue_depth: -1 }, + }, + isLoading: false, + isError: false, + error: null, + refetch: vi.fn(), + }); + mockAgentQuery.mockReturnValue({ + data: { data: agentData }, + isLoading: false, + isError: false, + error: null, + refetch: vi.fn(), + }); + mockScheduleQuery.mockReturnValue({ + data: { data: scheduleData }, + isLoading: false, + isError: false, + error: null, + refetch: vi.fn(), + }); + render(); + const errorTexts = screen.getAllByText("Error"); + expect(errorTexts.length).toBeGreaterThanOrEqual(1); + }); + + it("renders stuck queued count in queue status card", () => { + setupLoadedMocks(); + render(); + expect(screen.getByText(/2 stuck/)).toBeDefined(); + }); + + it("renders schedule orphaned count in card", () => { + setupLoadedMocks(); + render(); + expect(screen.getByText(/3 orphaned/)).toBeDefined(); + }); + + it("clicking orphaned alert card does not crash", () => { + setupLoadedMocks(); + render(); + fireEvent.click(screen.getByText("Orphaned Executions")); + }); + + it("clicking failed alert card does not crash", () => { + setupLoadedMocks(); + render(); + fireEvent.click(screen.getByText("Failed Executions (24h)")); + }); + + it("clicking long-running alert card does not crash", () => { + setupLoadedMocks(); + render(); + fireEvent.click(screen.getByText("Long-Running Executions")); + }); + + it("clicking orphaned schedules alert card does not crash", () => { + setupLoadedMocks(); + render(); + fireEvent.click(screen.getByText("Orphaned Schedules")); + }); + + it("clicking invalid states alert card does not crash", () => { + setupLoadedMocks(); + render(); + fireEvent.click(screen.getByText("Invalid States (Data Corruption)")); + }); + + it("renders orphan detail text in schedule alert", () => { + setupLoadedMocks(); + render(); + expect(screen.getByText(/2 deleted graph/)).toBeDefined(); + expect(screen.getByText(/1 no access/)).toBeDefined(); + }); + + it("renders failure rate in failed alert card", () => { + setupLoadedMocks(); + render(); + expect(screen.getByText(/0.8\/hr rate/)).toBeDefined(); + }); + + it("renders click to view text on alert cards", () => { + setupLoadedMocks(); + render(); + const clickTexts = screen.getAllByText(/Click to view/); + expect(clickTexts.length).toBeGreaterThanOrEqual(3); + }); + + it("renders schedule next hour count", () => { + setupLoadedMocks(); + render(); + expect(screen.getByText(/from 4 schedules/)).toBeDefined(); + }); + + it("clicking Refresh button calls all refetch functions", () => { + const refetchExec = vi.fn(); + const refetchAgent = vi.fn(); + const refetchSchedule = vi.fn(); + mockExecQuery.mockReturnValue({ + data: { data: executionData }, + isLoading: false, + isError: false, + error: null, + refetch: refetchExec, + }); + mockAgentQuery.mockReturnValue({ + data: { data: agentData }, + isLoading: false, + isError: false, + error: null, + refetch: refetchAgent, + }); + mockScheduleQuery.mockReturnValue({ + data: { data: scheduleData }, + isLoading: false, + isError: false, + error: null, + refetch: refetchSchedule, + }); + render(); + fireEvent.click(screen.getByText("Refresh")); + expect(refetchExec).toHaveBeenCalled(); + expect(refetchAgent).toHaveBeenCalled(); + expect(refetchSchedule).toHaveBeenCalled(); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/__tests__/ExecutionsTable.test.tsx b/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/__tests__/ExecutionsTable.test.tsx new file mode 100644 index 0000000000..e116d220e2 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/__tests__/ExecutionsTable.test.tsx @@ -0,0 +1,1258 @@ +import { + render, + screen, + cleanup, + fireEvent, + waitFor, +} from "@/tests/integrations/test-utils"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import { ExecutionsTable } from "../components/ExecutionsTable"; + +const mockRunningQuery = vi.fn(); +const mockOrphanedQuery = vi.fn(); +const mockFailedQuery = vi.fn(); +const mockLongRunningQuery = vi.fn(); +const mockStuckQueuedQuery = vi.fn(); +const mockInvalidQuery = vi.fn(); +const mockStopSingle = vi.fn(); +const mockStopMultiple = vi.fn(); +const mockStopAllLongRunning = vi.fn(); +const mockCleanupOrphaned = vi.fn(); +const mockCleanupAllOrphaned = vi.fn(); +const mockCleanupAllStuckQueued = vi.fn(); +const mockRequeueSingle = vi.fn(); +const mockRequeueMultiple = vi.fn(); +const mockRequeueAllStuck = vi.fn(); + +vi.mock("@/app/api/__generated__/endpoints/admin/admin", () => ({ + useGetV2ListRunningExecutions: (...args: unknown[]) => + mockRunningQuery(...args), + useGetV2ListOrphanedExecutions: (...args: unknown[]) => + mockOrphanedQuery(...args), + useGetV2ListFailedExecutions: (...args: unknown[]) => + mockFailedQuery(...args), + useGetV2ListLongRunningExecutions: (...args: unknown[]) => + mockLongRunningQuery(...args), + useGetV2ListStuckQueuedExecutions: (...args: unknown[]) => + mockStuckQueuedQuery(...args), + useGetV2ListInvalidExecutions: (...args: unknown[]) => + mockInvalidQuery(...args), + usePostV2StopSingleExecution: () => ({ + mutateAsync: mockStopSingle, + isPending: false, + }), + usePostV2StopMultipleExecutions: () => ({ + mutateAsync: mockStopMultiple, + isPending: false, + }), + usePostV2StopAllLongRunningExecutions: () => ({ + mutateAsync: mockStopAllLongRunning, + isPending: false, + }), + usePostV2CleanupOrphanedExecutions: () => ({ + mutateAsync: mockCleanupOrphaned, + isPending: false, + }), + usePostV2CleanupAllOrphanedExecutions: () => ({ + mutateAsync: mockCleanupAllOrphaned, + isPending: false, + }), + usePostV2CleanupAllStuckQueuedExecutions: () => ({ + mutateAsync: mockCleanupAllStuckQueued, + isPending: false, + }), + usePostV2RequeueStuckExecution: () => ({ + mutateAsync: mockRequeueSingle, + isPending: false, + }), + usePostV2RequeueMultipleStuckExecutions: () => ({ + mutateAsync: mockRequeueMultiple, + isPending: false, + }), + usePostV2RequeueAllStuckQueuedExecutions: () => ({ + mutateAsync: mockRequeueAllStuck, + isPending: false, + }), +})); + +function defaultQueryReturn(overrides = {}) { + return { + data: undefined, + isLoading: false, + error: null, + refetch: vi.fn(), + ...overrides, + }; +} + +function withExecutions( + executions: Record[], + total: number, + overrides = {}, +) { + return defaultQueryReturn({ + data: { data: { executions, total } }, + ...overrides, + }); +} + +const sampleExecution = { + execution_id: "exec-001", + graph_id: "graph-123", + graph_name: "Test Agent", + graph_version: 1, + user_id: "user-abc", + user_email: "alice@example.com", + status: "RUNNING", + created_at: "2026-04-16T10:00:00Z", + started_at: "2026-04-16T10:01:00Z", + queue_status: null, +}; + +const diagnosticsData = { + orphaned_running: 2, + orphaned_queued: 1, + failed_count_24h: 5, + stuck_running_24h: 3, + stuck_queued_1h: 2, + invalid_queued_with_start: 1, + invalid_running_without_start: 1, +}; + +function setupDefaultMocks() { + mockRunningQuery.mockReturnValue(defaultQueryReturn()); + mockOrphanedQuery.mockReturnValue(defaultQueryReturn()); + mockFailedQuery.mockReturnValue(defaultQueryReturn()); + mockLongRunningQuery.mockReturnValue(defaultQueryReturn()); + mockStuckQueuedQuery.mockReturnValue(defaultQueryReturn()); + mockInvalidQuery.mockReturnValue(defaultQueryReturn()); +} + +afterEach(() => { + cleanup(); + mockRunningQuery.mockReset(); + mockOrphanedQuery.mockReset(); + mockFailedQuery.mockReset(); + mockLongRunningQuery.mockReset(); + mockStuckQueuedQuery.mockReset(); + mockInvalidQuery.mockReset(); +}); + +describe("ExecutionsTable", () => { + it("shows empty state when no executions", () => { + setupDefaultMocks(); + mockRunningQuery.mockReturnValue(withExecutions([], 0)); + render(); + expect(screen.getByText("No running executions")).toBeDefined(); + }); + + it("renders execution rows in all tab", () => { + setupDefaultMocks(); + mockRunningQuery.mockReturnValue(withExecutions([sampleExecution], 1)); + render(); + expect(screen.getByText("Test Agent")).toBeDefined(); + expect(screen.getByText("alice@example.com")).toBeDefined(); + expect(screen.getByText("RUNNING")).toBeDefined(); + }); + + it("shows loading spinner", () => { + setupDefaultMocks(); + mockRunningQuery.mockReturnValue(defaultQueryReturn({ isLoading: true })); + render(); + expect(document.querySelector(".animate-spin")).toBeDefined(); + }); + + it("renders tab triggers with counts from diagnostics data", () => { + setupDefaultMocks(); + mockRunningQuery.mockReturnValue(withExecutions([], 0)); + render(); + expect(screen.getByText(/Orphaned/)).toBeDefined(); + expect(screen.getByText(/Failed/)).toBeDefined(); + expect(screen.getByText(/Long-Running/)).toBeDefined(); + expect(screen.getByText(/Stuck Queued/)).toBeDefined(); + expect(screen.getByText(/Invalid/)).toBeDefined(); + }); + + it("renders error state", () => { + setupDefaultMocks(); + mockRunningQuery.mockReturnValue( + defaultQueryReturn({ error: { status: 500, message: "Server down" } }), + ); + render(); + expect(screen.getByText("Try Again")).toBeDefined(); + }); + + it("renders failed execution with error message", () => { + setupDefaultMocks(); + const failedExec = { + ...sampleExecution, + execution_id: "exec-fail-1", + status: "FAILED", + failed_at: "2026-04-16T12:00:00Z", + error_message: "Out of memory", + }; + mockRunningQuery.mockReturnValue(withExecutions([], 0)); + mockFailedQuery.mockReturnValue(withExecutions([failedExec], 1)); + render( + , + ); + expect(screen.getByText("Out of memory")).toBeDefined(); + }); + + it("renders pagination when total exceeds page size", () => { + setupDefaultMocks(); + const executions = Array.from({ length: 10 }, (_, i) => ({ + ...sampleExecution, + execution_id: `exec-${i}`, + })); + mockRunningQuery.mockReturnValue(withExecutions(executions, 25)); + render(); + expect(screen.getByText(/Page 1 of 3/)).toBeDefined(); + expect(screen.getByText("Previous")).toBeDefined(); + expect(screen.getByText("Next")).toBeDefined(); + }); + + it("shows unknown for null user email", () => { + setupDefaultMocks(); + const noEmailExec = { + ...sampleExecution, + user_email: null, + }; + mockRunningQuery.mockReturnValue(withExecutions([noEmailExec], 1)); + render(); + expect(screen.getByText("Unknown")).toBeDefined(); + }); + + it("copies execution ID to clipboard on click", () => { + const writeText = vi.fn().mockResolvedValue(undefined); + vi.stubGlobal("navigator", { ...navigator, clipboard: { writeText } }); + setupDefaultMocks(); + mockRunningQuery.mockReturnValue(withExecutions([sampleExecution], 1)); + render(); + fireEvent.click(screen.getByText("exec-001".substring(0, 8) + "...")); + expect(writeText).toHaveBeenCalledWith("exec-001"); + vi.unstubAllGlobals(); + }); + + it("copies user ID to clipboard on click", () => { + const writeText = vi.fn().mockResolvedValue(undefined); + vi.stubGlobal("navigator", { ...navigator, clipboard: { writeText } }); + setupDefaultMocks(); + mockRunningQuery.mockReturnValue(withExecutions([sampleExecution], 1)); + render(); + fireEvent.click(screen.getByText("user-abc".substring(0, 8) + "...")); + expect(writeText).toHaveBeenCalledWith("user-abc"); + vi.unstubAllGlobals(); + }); + + it("shows never started for null started_at", () => { + setupDefaultMocks(); + const neverStarted = { + ...sampleExecution, + started_at: null, + }; + mockRunningQuery.mockReturnValue(withExecutions([neverStarted], 1)); + render(); + expect(screen.getByText("Never started")).toBeDefined(); + }); + + it("renders stuck-queued tab with requeue buttons", () => { + setupDefaultMocks(); + const stuckExec = { + ...sampleExecution, + execution_id: "exec-stuck-1", + status: "QUEUED", + started_at: null, + }; + mockStuckQueuedQuery.mockReturnValue(withExecutions([stuckExec], 1)); + render( + , + ); + expect(screen.getByTitle("Cleanup (mark as FAILED)")).toBeDefined(); + expect(screen.getByTitle("Requeue (send to RabbitMQ)")).toBeDefined(); + }); + + it("renders orphaned tab executions", () => { + setupDefaultMocks(); + const orphanedExec = { + ...sampleExecution, + execution_id: "exec-orphan-1", + created_at: "2026-04-10T10:00:00Z", + }; + mockOrphanedQuery.mockReturnValue(withExecutions([orphanedExec], 1)); + render( + , + ); + expect(screen.getByText("Test Agent")).toBeDefined(); + }); + + it("renders long-running tab executions", () => { + setupDefaultMocks(); + mockLongRunningQuery.mockReturnValue(withExecutions([sampleExecution], 1)); + render( + , + ); + expect(screen.getByText("Test Agent")).toBeDefined(); + }); + + it("renders invalid tab executions", () => { + setupDefaultMocks(); + const invalidExec = { + ...sampleExecution, + execution_id: "exec-invalid-1", + status: "QUEUED", + started_at: "2026-04-16T10:01:00Z", + }; + mockInvalidQuery.mockReturnValue(withExecutions([invalidExec], 1)); + render( + , + ); + expect(screen.getByText("QUEUED")).toBeDefined(); + }); + + it("renders all tab trigger labels with correct counts", () => { + setupDefaultMocks(); + mockRunningQuery.mockReturnValue(withExecutions([], 0)); + render(); + expect(screen.getByText(/Orphaned.*3/)).toBeDefined(); + expect(screen.getByText(/Failed.*5/)).toBeDefined(); + expect(screen.getByText(/Stuck Queued.*2/)).toBeDefined(); + expect(screen.getByText(/Long-Running.*3/)).toBeDefined(); + expect(screen.getByText(/Invalid States.*2/)).toBeDefined(); + }); + + it("shows graph version number", () => { + setupDefaultMocks(); + mockRunningQuery.mockReturnValue(withExecutions([sampleExecution], 1)); + render(); + expect(screen.getByText("1")).toBeDefined(); + }); + + it("renders QUEUED status badge", () => { + setupDefaultMocks(); + const queuedExec = { ...sampleExecution, status: "QUEUED" }; + mockRunningQuery.mockReturnValue(withExecutions([queuedExec], 1)); + render(); + expect(screen.getByText("QUEUED")).toBeDefined(); + }); + + it("renders without diagnosticsData", () => { + setupDefaultMocks(); + mockRunningQuery.mockReturnValue(withExecutions([], 0)); + render(); + expect(screen.getByText(/All/)).toBeDefined(); + }); + + it("renders stuck-queued bulk action buttons when total > 0", () => { + setupDefaultMocks(); + const stuckExec = { + ...sampleExecution, + status: "QUEUED", + started_at: null, + }; + mockStuckQueuedQuery.mockReturnValue(withExecutions([stuckExec], 5)); + render( + , + ); + expect(screen.getByText(/Cleanup All \(5\)/)).toBeDefined(); + expect(screen.getByText(/Requeue All \(5\)/)).toBeDefined(); + }); + + it("renders long-running stop all button when total > 0", () => { + setupDefaultMocks(); + mockLongRunningQuery.mockReturnValue(withExecutions([sampleExecution], 3)); + render( + , + ); + expect(screen.getByText(/Stop All Long-Running \(3\)/)).toBeDefined(); + }); + + it("shows invalid state read-only banner", () => { + setupDefaultMocks(); + mockInvalidQuery.mockReturnValue(withExecutions([], 0)); + render( + , + ); + expect( + screen.getByText( + /Read-only: Invalid states require manual investigation/, + ), + ).toBeDefined(); + }); + + it("shows view-only message in failed tab with no selection", () => { + setupDefaultMocks(); + const failedExec = { + ...sampleExecution, + status: "FAILED", + error_message: "err", + }; + mockFailedQuery.mockReturnValue(withExecutions([failedExec], 1)); + render( + , + ); + expect(screen.getByText("View-only (select to delete)")).toBeDefined(); + }); + + it("renders table column headers", () => { + setupDefaultMocks(); + mockRunningQuery.mockReturnValue(withExecutions([sampleExecution], 1)); + render(); + expect(screen.getByText("Execution ID")).toBeDefined(); + expect(screen.getByText("Agent Name")).toBeDefined(); + expect(screen.getByText("Version")).toBeDefined(); + expect(screen.getByText("User")).toBeDefined(); + expect(screen.getByText("Status")).toBeDefined(); + expect(screen.getByText("Age")).toBeDefined(); + }); + + it("renders failed tab with error column header", () => { + setupDefaultMocks(); + const failedExec = { + ...sampleExecution, + status: "FAILED", + failed_at: "2026-04-16T12:00:00Z", + error_message: "Timeout", + }; + mockFailedQuery.mockReturnValue(withExecutions([failedExec], 1)); + render( + , + ); + expect(screen.getByText("Error Message")).toBeDefined(); + expect(screen.getByText("Timeout")).toBeDefined(); + }); + + it("renders no error message text when error_message is null", () => { + setupDefaultMocks(); + const failedNoMsg = { + ...sampleExecution, + status: "FAILED", + failed_at: "2026-04-16T12:00:00Z", + error_message: null, + }; + mockFailedQuery.mockReturnValue(withExecutions([failedNoMsg], 1)); + render( + , + ); + expect(screen.getByText("No error message")).toBeDefined(); + }); + + it("renders started_at as dash when null in non-failed tab", () => { + setupDefaultMocks(); + const noStart = { ...sampleExecution, started_at: null }; + mockRunningQuery.mockReturnValue(withExecutions([noStart], 1)); + render(); + const dashes = screen.getAllByText("-"); + expect(dashes.length).toBeGreaterThanOrEqual(1); + }); + + it("renders failed_at as dash when null in failed tab", () => { + setupDefaultMocks(); + const failedNoDate = { + ...sampleExecution, + status: "FAILED", + failed_at: null, + error_message: "err", + }; + mockFailedQuery.mockReturnValue(withExecutions([failedNoDate], 1)); + render( + , + ); + const dashes = screen.getAllByText("-"); + expect(dashes.length).toBeGreaterThanOrEqual(1); + }); + + it("renders Executions card title", () => { + setupDefaultMocks(); + mockRunningQuery.mockReturnValue(withExecutions([], 0)); + render(); + expect(screen.getByText("Executions")).toBeDefined(); + }); + + it("opens stop dialog when clicking cleanup button on stuck-queued row", async () => { + setupDefaultMocks(); + const stuckExec = { + ...sampleExecution, + execution_id: "exec-stuck-dialog", + status: "QUEUED", + started_at: null, + }; + mockStuckQueuedQuery.mockReturnValue(withExecutions([stuckExec], 1)); + render( + , + ); + fireEvent.click(screen.getByTitle("Cleanup (mark as FAILED)")); + await waitFor(() => { + expect( + screen.getByText("Confirm Cleanup Orphaned Executions"), + ).toBeDefined(); + expect(screen.getByText("Cancel")).toBeDefined(); + expect(screen.getByText("Cleanup Orphaned")).toBeDefined(); + }); + }); + + it("calls cleanupOrphanedExecutions when confirming single cleanup", async () => { + setupDefaultMocks(); + mockCleanupOrphaned.mockResolvedValue({ + data: { success: true, stopped_count: 1, message: "Cleaned" }, + }); + const stuckExec = { + ...sampleExecution, + execution_id: "exec-stuck-confirm", + status: "QUEUED", + started_at: null, + }; + mockStuckQueuedQuery.mockReturnValue(withExecutions([stuckExec], 1)); + render( + , + ); + fireEvent.click(screen.getByTitle("Cleanup (mark as FAILED)")); + await waitFor(() => { + expect(screen.getByText("Cleanup Orphaned")).toBeDefined(); + }); + fireEvent.click(screen.getByText("Cleanup Orphaned")); + await waitFor(() => { + expect(mockCleanupOrphaned).toHaveBeenCalled(); + }); + }); + + it("opens cleanup dialog for stuck-queued execution", async () => { + setupDefaultMocks(); + const stuckExec = { + ...sampleExecution, + execution_id: "exec-stuck-1", + status: "QUEUED", + started_at: null, + }; + mockStuckQueuedQuery.mockReturnValue(withExecutions([stuckExec], 1)); + render( + , + ); + fireEvent.click(screen.getByTitle("Cleanup (mark as FAILED)")); + await waitFor(() => { + expect( + screen.getByText("Confirm Cleanup Orphaned Executions"), + ).toBeDefined(); + expect(screen.getByText("Cleanup Orphaned")).toBeDefined(); + }); + }); + + it("calls cleanupOrphanedExecutions when confirming cleanup", async () => { + setupDefaultMocks(); + mockCleanupOrphaned.mockResolvedValue({ + data: { success: true, stopped_count: 1, message: "Cleaned" }, + }); + const stuckExec = { + ...sampleExecution, + execution_id: "exec-stuck-1", + status: "QUEUED", + started_at: null, + }; + mockStuckQueuedQuery.mockReturnValue(withExecutions([stuckExec], 1)); + render( + , + ); + fireEvent.click(screen.getByTitle("Cleanup (mark as FAILED)")); + await waitFor(() => { + expect(screen.getByText("Cleanup Orphaned")).toBeDefined(); + }); + fireEvent.click(screen.getByText("Cleanup Orphaned")); + await waitFor(() => { + expect(mockCleanupOrphaned).toHaveBeenCalled(); + }); + }); + + it("opens requeue dialog for stuck-queued execution", async () => { + setupDefaultMocks(); + const stuckExec = { + ...sampleExecution, + execution_id: "exec-stuck-1", + status: "QUEUED", + started_at: null, + }; + mockStuckQueuedQuery.mockReturnValue(withExecutions([stuckExec], 1)); + render( + , + ); + fireEvent.click(screen.getByTitle("Requeue (send to RabbitMQ)")); + await waitFor(() => { + expect( + screen.getByText("Confirm Requeue Stuck Executions"), + ).toBeDefined(); + expect(screen.getByText("Requeue Executions")).toBeDefined(); + }); + }); + + it("calls requeueSingleExecution when confirming requeue", async () => { + setupDefaultMocks(); + mockRequeueSingle.mockResolvedValue({ + data: { success: true, requeued_count: 1, message: "Requeued" }, + }); + const stuckExec = { + ...sampleExecution, + execution_id: "exec-stuck-1", + status: "QUEUED", + started_at: null, + }; + mockStuckQueuedQuery.mockReturnValue(withExecutions([stuckExec], 1)); + render( + , + ); + fireEvent.click(screen.getByTitle("Requeue (send to RabbitMQ)")); + await waitFor(() => { + expect(screen.getByText("Requeue Executions")).toBeDefined(); + }); + fireEvent.click(screen.getByText("Requeue Executions")); + await waitFor(() => { + expect(mockRequeueSingle).toHaveBeenCalled(); + }); + }); + + it("closes dialog when cancel is clicked", async () => { + setupDefaultMocks(); + const stuckExec = { + ...sampleExecution, + execution_id: "exec-cancel-test", + status: "QUEUED", + started_at: null, + }; + mockStuckQueuedQuery.mockReturnValue(withExecutions([stuckExec], 1)); + render( + , + ); + fireEvent.click(screen.getByTitle("Cleanup (mark as FAILED)")); + await waitFor(() => { + expect( + screen.getByText("Confirm Cleanup Orphaned Executions"), + ).toBeDefined(); + }); + fireEvent.click(screen.getByText("Cancel")); + await waitFor(() => { + expect( + screen.queryByText("Confirm Cleanup Orphaned Executions"), + ).toBeNull(); + }); + }); + + it("handles cleanup mutation error gracefully", async () => { + setupDefaultMocks(); + mockCleanupOrphaned.mockRejectedValue(new Error("Network error")); + const stuckExec = { + ...sampleExecution, + execution_id: "exec-error-test", + status: "QUEUED", + started_at: null, + }; + mockStuckQueuedQuery.mockReturnValue(withExecutions([stuckExec], 1)); + render( + , + ); + fireEvent.click(screen.getByTitle("Cleanup (mark as FAILED)")); + await waitFor(() => { + expect(screen.getByText("Cleanup Orphaned")).toBeDefined(); + }); + fireEvent.click(screen.getByText("Cleanup Orphaned")); + await waitFor(() => { + expect(mockCleanupOrphaned).toHaveBeenCalled(); + }); + }); + + it("calls requeueAllStuck when clicking Requeue All button and confirming", async () => { + setupDefaultMocks(); + mockRequeueAllStuck.mockResolvedValue({ + data: { success: true, requeued_count: 5, message: "Requeued 5" }, + }); + const stuckExecs = Array.from({ length: 3 }, (_, i) => ({ + ...sampleExecution, + execution_id: `exec-stuck-${i}`, + status: "QUEUED", + started_at: null, + })); + mockStuckQueuedQuery.mockReturnValue(withExecutions(stuckExecs, 5)); + render( + , + ); + fireEvent.click(screen.getByText(/Requeue All \(5\)/)); + await waitFor(() => { + expect( + screen.getByText("Confirm Requeue Stuck Executions"), + ).toBeDefined(); + }); + fireEvent.click(screen.getByText("Requeue Executions")); + await waitFor(() => { + expect(mockRequeueAllStuck).toHaveBeenCalled(); + }); + }); + + it("calls cleanupAllStuckQueued when clicking Cleanup All on stuck-queued tab", async () => { + setupDefaultMocks(); + mockCleanupAllStuckQueued.mockResolvedValue({ + data: { success: true, stopped_count: 5, message: "Cleaned 5" }, + }); + const stuckExecs = Array.from({ length: 3 }, (_, i) => ({ + ...sampleExecution, + execution_id: `exec-stuck-${i}`, + status: "QUEUED", + started_at: null, + })); + mockStuckQueuedQuery.mockReturnValue(withExecutions(stuckExecs, 5)); + render( + , + ); + fireEvent.click(screen.getByText(/Cleanup All \(5\)/)); + await waitFor(() => { + expect( + screen.getByText("Confirm Cleanup Orphaned Executions"), + ).toBeDefined(); + }); + fireEvent.click(screen.getByText("Cleanup Orphaned")); + await waitFor(() => { + expect(mockCleanupAllStuckQueued).toHaveBeenCalled(); + }); + }); + + it("calls stopAllLongRunning when clicking Stop All Long-Running", async () => { + setupDefaultMocks(); + mockStopAllLongRunning.mockResolvedValue({ + data: { success: true, stopped_count: 3, message: "Stopped 3" }, + }); + mockLongRunningQuery.mockReturnValue(withExecutions([sampleExecution], 3)); + render( + , + ); + fireEvent.click(screen.getByText(/Stop All Long-Running \(3\)/)); + await waitFor(() => { + expect(screen.getByText("Confirm Stop Executions")).toBeDefined(); + }); + fireEvent.click(screen.getByText("Stop Executions")); + await waitFor(() => { + expect(mockStopAllLongRunning).toHaveBeenCalled(); + }); + }); + + it("shows requeue warning text in dialog", async () => { + setupDefaultMocks(); + const stuckExec = { + ...sampleExecution, + execution_id: "exec-stuck-warn", + status: "QUEUED", + started_at: null, + }; + mockStuckQueuedQuery.mockReturnValue(withExecutions([stuckExec], 1)); + render( + , + ); + fireEvent.click(screen.getByTitle("Requeue (send to RabbitMQ)")); + await waitFor(() => { + expect(screen.getByText(/will cost credits/)).toBeDefined(); + }); + }); + + it("shows cleanup description in dialog", async () => { + setupDefaultMocks(); + const stuckExec = { + ...sampleExecution, + execution_id: "exec-stuck-desc", + status: "QUEUED", + started_at: null, + }; + mockStuckQueuedQuery.mockReturnValue(withExecutions([stuckExec], 1)); + render( + , + ); + fireEvent.click(screen.getByTitle("Cleanup (mark as FAILED)")); + await waitFor(() => { + expect(screen.getByText(/cleanup this orphaned execution/)).toBeDefined(); + }); + }); + + it("renders age in days format for old executions", () => { + setupDefaultMocks(); + const oldExec = { + ...sampleExecution, + started_at: new Date(Date.now() - 3 * 24 * 60 * 60 * 1000).toISOString(), + }; + mockRunningQuery.mockReturnValue(withExecutions([oldExec], 1)); + render(); + expect(screen.getByText(/3d/)).toBeDefined(); + }); + + it("shows stop selected button after selecting a checkbox", async () => { + setupDefaultMocks(); + mockRunningQuery.mockReturnValue(withExecutions([sampleExecution], 1)); + render(); + const checkboxes = document.querySelectorAll('[role="checkbox"]'); + if (checkboxes[1]) fireEvent.click(checkboxes[1]); + await waitFor(() => { + expect(screen.getByText(/Stop Selected/)).toBeDefined(); + }); + }); + + it("shows stop selected button with count after selection", async () => { + setupDefaultMocks(); + mockRunningQuery.mockReturnValue(withExecutions([sampleExecution], 1)); + render(); + const checkboxes = document.querySelectorAll('[role="checkbox"]'); + if (checkboxes[1]) fireEvent.click(checkboxes[1]); + await waitFor(() => { + expect(screen.getByText(/Stop Selected \(1\)/)).toBeDefined(); + }); + }); + + it("renders select-all checkbox", () => { + setupDefaultMocks(); + mockRunningQuery.mockReturnValue(withExecutions([sampleExecution], 1)); + render(); + const checkboxes = document.querySelectorAll('[role="checkbox"]'); + expect(checkboxes.length).toBeGreaterThanOrEqual(2); + }); + + it("selects all checkboxes with select-all", async () => { + setupDefaultMocks(); + const execs = [ + { ...sampleExecution, execution_id: "exec-a" }, + { ...sampleExecution, execution_id: "exec-b" }, + ]; + mockRunningQuery.mockReturnValue(withExecutions(execs, 2)); + render(); + const checkboxes = document.querySelectorAll('[role="checkbox"]'); + // First checkbox is select-all + if (checkboxes[0]) fireEvent.click(checkboxes[0]); + await waitFor(() => { + expect(screen.getByText(/Stop Selected \(2\)/)).toBeDefined(); + }); + }); + + it("renders hours format for recent execution age", () => { + setupDefaultMocks(); + const recentExec = { + ...sampleExecution, + started_at: new Date(Date.now() - 5 * 60 * 60 * 1000).toISOString(), + }; + mockRunningQuery.mockReturnValue(withExecutions([recentExec], 1)); + render(); + expect(screen.getByText(/5h/)).toBeDefined(); + }); + + it("calls onRefresh when provided", async () => { + setupDefaultMocks(); + const onRefresh = vi.fn(); + mockStopSingle.mockResolvedValue({ + data: { success: true, stopped_count: 1, message: "Stopped" }, + }); + const stuckExec = { + ...sampleExecution, + execution_id: "exec-refresh-test", + status: "QUEUED", + started_at: null, + }; + mockStuckQueuedQuery.mockReturnValue(withExecutions([stuckExec], 1)); + render( + , + ); + fireEvent.click(screen.getByTitle("Cleanup (mark as FAILED)")); + await waitFor(() => { + expect(screen.getByText("Cleanup Orphaned")).toBeDefined(); + }); + mockCleanupOrphaned.mockResolvedValue({ + data: { success: true, stopped_count: 1, message: "OK" }, + }); + fireEvent.click(screen.getByText("Cleanup Orphaned")); + await waitFor(() => { + expect(onRefresh).toHaveBeenCalled(); + }); + }); + + it("renders showing count text in pagination", () => { + setupDefaultMocks(); + const executions = Array.from({ length: 10 }, (_, i) => ({ + ...sampleExecution, + execution_id: `exec-page-${i}`, + })); + mockRunningQuery.mockReturnValue(withExecutions(executions, 30)); + render(); + expect(screen.getByText(/Showing 1 to 10 of 30/)).toBeDefined(); + }); + + it("disables Previous button on first page", () => { + setupDefaultMocks(); + const executions = Array.from({ length: 10 }, (_, i) => ({ + ...sampleExecution, + execution_id: `exec-dis-${i}`, + })); + mockRunningQuery.mockReturnValue(withExecutions(executions, 25)); + render(); + const prevBtn = screen.getByText("Previous").closest("button"); + expect(prevBtn?.disabled).toBe(true); + }); + + it("enables Next button when more pages exist", () => { + setupDefaultMocks(); + const executions = Array.from({ length: 10 }, (_, i) => ({ + ...sampleExecution, + execution_id: `exec-next-${i}`, + })); + mockRunningQuery.mockReturnValue(withExecutions(executions, 25)); + render(); + const nextBtn = screen.getByText("Next").closest("button"); + expect(nextBtn?.disabled).toBe(false); + }); + + it("renders orphaned execution with orange background", () => { + setupDefaultMocks(); + const orphanedExec = { + ...sampleExecution, + execution_id: "exec-orange", + created_at: "2026-04-10T10:00:00Z", + }; + mockOrphanedQuery.mockReturnValue(withExecutions([orphanedExec], 1)); + render( + , + ); + const row = screen.getByText("Test Agent").closest("tr"); + expect(row?.className).toContain("bg-orange"); + }); + + it("renders initialTab syncs with useEffect", () => { + setupDefaultMocks(); + mockFailedQuery.mockReturnValue( + withExecutions( + [ + { + ...sampleExecution, + execution_id: "exec-sync", + status: "FAILED", + error_message: "sync test", + }, + ], + 1, + ), + ); + const { rerender } = render( + , + ); + // Rerender with new initialTab to trigger useEffect sync + rerender( + , + ); + expect(screen.getByText("sync test")).toBeDefined(); + }); + + it("renders the all tab total count", () => { + setupDefaultMocks(); + mockRunningQuery.mockReturnValue(withExecutions([sampleExecution], 7)); + render(); + // "All (7)" in the tab trigger + expect(screen.getByText(/All.*7/)).toBeDefined(); + }); + + it("opens stop dialog and calls mutations for selected executions", async () => { + setupDefaultMocks(); + mockStopMultiple.mockResolvedValue({ + data: { success: true, stopped_count: 1, message: "Stopped 1" }, + }); + mockCleanupOrphaned.mockResolvedValue({ + data: { success: true, stopped_count: 0, message: "OK" }, + }); + // Use a recent execution that won't be classified as orphaned + const recentExec = { + ...sampleExecution, + execution_id: "exec-recent-stop", + created_at: new Date().toISOString(), + }; + mockRunningQuery.mockReturnValue(withExecutions([recentExec], 1)); + render(); + // Select execution + const checkboxes = document.querySelectorAll('[role="checkbox"]'); + if (checkboxes[1]) fireEvent.click(checkboxes[1]); + await waitFor(() => { + expect(screen.getByText(/Stop Selected/)).toBeDefined(); + }); + // Click stop selected + fireEvent.click(screen.getByText(/Stop Selected/)); + // Dialog should open + await waitFor(() => { + expect(screen.getByText("Confirm Stop Executions")).toBeDefined(); + }); + // Confirm + fireEvent.click(screen.getByText("Stop Executions")); + await waitFor(() => { + expect(mockStopMultiple).toHaveBeenCalled(); + }); + }); + + it("calls requeueMultiple for selected stuck-queued executions", async () => { + setupDefaultMocks(); + mockRequeueMultiple.mockResolvedValue({ + data: { success: true, requeued_count: 2, message: "Requeued 2" }, + }); + const stuckExecs = [ + { + ...sampleExecution, + execution_id: "stuck-a", + status: "QUEUED", + started_at: null, + }, + { + ...sampleExecution, + execution_id: "stuck-b", + status: "QUEUED", + started_at: null, + }, + ]; + mockStuckQueuedQuery.mockReturnValue(withExecutions(stuckExecs, 2)); + render( + , + ); + // Select all via select-all checkbox + const checkboxes = document.querySelectorAll('[role="checkbox"]'); + if (checkboxes[0]) fireEvent.click(checkboxes[0]); + // In stuck-queued tab, no "Stop Selected" button - only Cleanup All / Requeue All + // Use Requeue All button instead + await waitFor(() => { + expect(screen.getByText(/Requeue All \(2\)/)).toBeDefined(); + }); + fireEvent.click(screen.getByText(/Requeue All \(2\)/)); + await waitFor(() => { + expect(screen.getByText("Requeue Executions")).toBeDefined(); + }); + fireEvent.click(screen.getByText("Requeue Executions")); + await waitFor(() => { + expect(mockRequeueAllStuck).toHaveBeenCalled(); + }); + }); + + it("shows dialog description for stop all on long-running tab", async () => { + setupDefaultMocks(); + mockLongRunningQuery.mockReturnValue(withExecutions([sampleExecution], 1)); + render( + , + ); + fireEvent.click(screen.getByText(/Stop All Long-Running/)); + await waitFor(() => { + expect(screen.getByText(/stop ALL 1 execution/)).toBeDefined(); + }); + }); + + it("shows stop dialog description listing what it does", async () => { + setupDefaultMocks(); + mockLongRunningQuery.mockReturnValue(withExecutions([sampleExecution], 1)); + render( + , + ); + fireEvent.click(screen.getByText(/Stop All Long-Running/)); + await waitFor(() => { + expect( + screen.getByText(/Send cancel signals for active executions/), + ).toBeDefined(); + expect(screen.getByText(/Mark all as FAILED/)).toBeDefined(); + }); + }); + + it("clicking refresh button calls refetch and onRefresh", () => { + setupDefaultMocks(); + const onRefresh = vi.fn(); + const refetch = vi.fn(); + mockRunningQuery.mockReturnValue({ + data: { data: { executions: [sampleExecution], total: 1 } }, + isLoading: false, + error: null, + refetch, + }); + render( + , + ); + // The refresh button is the last button with ArrowClockwise icon in the header + const buttons = document.querySelectorAll("button"); + // Find the standalone refresh button (no text, just icon) + const refreshBtn = Array.from(buttons).find( + (b) => b.querySelector("svg") && b.textContent?.trim() === "", + ); + if (refreshBtn) { + fireEvent.click(refreshBtn); + expect(refetch).toHaveBeenCalled(); + expect(onRefresh).toHaveBeenCalled(); + } + }); + + it("renders executions text label in Showing pagination", () => { + setupDefaultMocks(); + const executions = Array.from({ length: 10 }, (_, i) => ({ + ...sampleExecution, + execution_id: `exec-label-${i}`, + })); + mockRunningQuery.mockReturnValue(withExecutions(executions, 20)); + render(); + expect(screen.getByText(/executions/)).toBeDefined(); + }); + + it("renders status badge with green for RUNNING", () => { + setupDefaultMocks(); + mockRunningQuery.mockReturnValue(withExecutions([sampleExecution], 1)); + render(); + const badge = screen.getByText("RUNNING"); + expect(badge.className).toContain("bg-green"); + }); + + it("renders status badge with yellow for QUEUED", () => { + setupDefaultMocks(); + const queuedExec = { ...sampleExecution, status: "QUEUED" }; + mockRunningQuery.mockReturnValue(withExecutions([queuedExec], 1)); + render(); + const badge = screen.getByText("QUEUED"); + expect(badge.className).toContain("bg-yellow"); + }); + + it("clicking Next advances pagination page", () => { + setupDefaultMocks(); + const executions = Array.from({ length: 10 }, (_, i) => ({ + ...sampleExecution, + execution_id: `exec-pagnext-${i}`, + })); + mockRunningQuery.mockReturnValue(withExecutions(executions, 25)); + render(); + expect(screen.getByText(/Page 1 of 3/)).toBeDefined(); + fireEvent.click(screen.getByText("Next")); + expect(screen.getByText(/Page 2 of 3/)).toBeDefined(); + }); + + it("clicking Previous goes back a page", () => { + setupDefaultMocks(); + const executions = Array.from({ length: 10 }, (_, i) => ({ + ...sampleExecution, + execution_id: `exec-pagprev-${i}`, + })); + mockRunningQuery.mockReturnValue(withExecutions(executions, 25)); + render(); + fireEvent.click(screen.getByText("Next")); + expect(screen.getByText(/Page 2 of 3/)).toBeDefined(); + fireEvent.click(screen.getByText("Previous")); + expect(screen.getByText(/Page 1 of 3/)).toBeDefined(); + }); + + it("splits orphaned and active IDs when stopping selected with old execution", async () => { + setupDefaultMocks(); + mockStopMultiple.mockResolvedValue({ + data: { success: true, stopped_count: 0, message: "OK" }, + }); + mockCleanupOrphaned.mockResolvedValue({ + data: { success: true, stopped_count: 1, message: "Cleaned 1" }, + }); + // Use an OLD execution (>24h) so it's classified as orphaned + const oldExec = { + ...sampleExecution, + execution_id: "exec-old-orphan", + created_at: new Date(Date.now() - 48 * 60 * 60 * 1000).toISOString(), + }; + mockRunningQuery.mockReturnValue(withExecutions([oldExec], 1)); + render(); + // Select the old execution + const checkboxes = document.querySelectorAll('[role="checkbox"]'); + if (checkboxes[1]) fireEvent.click(checkboxes[1]); + await waitFor(() => { + expect(screen.getByText(/Stop Selected/)).toBeDefined(); + }); + fireEvent.click(screen.getByText(/Stop Selected/)); + await waitFor(() => { + expect(screen.getByText("Stop Executions")).toBeDefined(); + }); + fireEvent.click(screen.getByText("Stop Executions")); + await waitFor(() => { + // Should call cleanupOrphaned for the old execution + expect(mockCleanupOrphaned).toHaveBeenCalled(); + }); + }); + + it("clicking Try Again on error state calls refetch", () => { + setupDefaultMocks(); + const refetch = vi.fn(); + mockRunningQuery.mockReturnValue({ + data: undefined, + isLoading: false, + error: { status: 500, message: "Server error" }, + refetch, + }); + render(); + fireEvent.click(screen.getByText("Try Again")); + expect(refetch).toHaveBeenCalled(); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/__tests__/SchedulesTable.test.tsx b/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/__tests__/SchedulesTable.test.tsx new file mode 100644 index 0000000000..a377fafe3c --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/__tests__/SchedulesTable.test.tsx @@ -0,0 +1,413 @@ +import { + render, + screen, + cleanup, + fireEvent, + waitFor, +} from "@/tests/integrations/test-utils"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import { SchedulesTable } from "../components/SchedulesTable"; + +const mockAllSchedulesQuery = vi.fn(); +const mockOrphanedSchedulesQuery = vi.fn(); +const mockCleanupOrphaned = vi.fn(); + +vi.mock("@/app/api/__generated__/endpoints/admin/admin", () => ({ + useGetV2ListAllUserSchedules: (...args: unknown[]) => + mockAllSchedulesQuery(...args), + useGetV2ListOrphanedSchedules: (...args: unknown[]) => + mockOrphanedSchedulesQuery(...args), + usePostV2CleanupOrphanedSchedules: () => ({ + mutateAsync: mockCleanupOrphaned, + isPending: false, + }), +})); + +function defaultQueryReturn(overrides = {}) { + return { + data: undefined, + isLoading: false, + error: null, + refetch: vi.fn(), + ...overrides, + }; +} + +function withSchedules( + schedules: Record[], + total: number, + overrides = {}, +) { + return defaultQueryReturn({ + data: { data: { schedules, total } }, + ...overrides, + }); +} + +const sampleSchedule = { + schedule_id: "sched-001", + schedule_name: "Daily Agent Run", + graph_id: "graph-123", + graph_name: "My Agent", + graph_version: 1, + user_id: "user-abc", + user_email: "alice@example.com", + cron: "0 9 * * *", + timezone: "America/New_York", + next_run_time: "2026-04-17T13:00:00Z", +}; + +const diagnosticsData = { + total_orphaned: 3, + user_schedules: 10, +}; + +function setupDefaultMocks() { + mockAllSchedulesQuery.mockReturnValue(defaultQueryReturn()); + mockOrphanedSchedulesQuery.mockReturnValue(defaultQueryReturn()); +} + +afterEach(() => { + cleanup(); + mockAllSchedulesQuery.mockReset(); + mockOrphanedSchedulesQuery.mockReset(); + mockCleanupOrphaned.mockReset(); +}); + +describe("SchedulesTable", () => { + it("shows empty state when no schedules", () => { + setupDefaultMocks(); + mockAllSchedulesQuery.mockReturnValue(withSchedules([], 0)); + render(); + expect(screen.getByText("No schedules found")).toBeDefined(); + }); + + it("renders schedule rows", () => { + setupDefaultMocks(); + mockAllSchedulesQuery.mockReturnValue(withSchedules([sampleSchedule], 1)); + render(); + expect(screen.getByText("Daily Agent Run")).toBeDefined(); + expect(screen.getByText("alice@example.com")).toBeDefined(); + expect(screen.getByText("0 9 * * *")).toBeDefined(); + expect(screen.getByText("America/New_York")).toBeDefined(); + }); + + it("renders tab triggers with counts", () => { + setupDefaultMocks(); + mockAllSchedulesQuery.mockReturnValue(withSchedules([], 0)); + render(); + expect(screen.getByText("All Schedules (10)")).toBeDefined(); + expect(screen.getByText("Orphaned (3)")).toBeDefined(); + }); + + it("shows loading spinner", () => { + setupDefaultMocks(); + mockAllSchedulesQuery.mockReturnValue( + defaultQueryReturn({ isLoading: true }), + ); + render(); + expect(document.querySelector(".animate-spin")).toBeDefined(); + }); + + it("renders graph version", () => { + setupDefaultMocks(); + mockAllSchedulesQuery.mockReturnValue(withSchedules([sampleSchedule], 1)); + render(); + expect(screen.getByText("v1")).toBeDefined(); + }); + + it("shows unknown for missing graph name", () => { + setupDefaultMocks(); + const noGraphSchedule = { ...sampleSchedule, graph_name: undefined }; + mockAllSchedulesQuery.mockReturnValue(withSchedules([noGraphSchedule], 1)); + render(); + expect(screen.getByText("Unknown")).toBeDefined(); + }); + + it("renders without diagnostics data", () => { + setupDefaultMocks(); + mockAllSchedulesQuery.mockReturnValue(withSchedules([], 0)); + render(); + expect(screen.getByText("All Schedules")).toBeDefined(); + expect(screen.getByText("Orphaned")).toBeDefined(); + }); + + it("renders pagination for many schedules", () => { + setupDefaultMocks(); + const schedules = Array.from({ length: 10 }, (_, i) => ({ + ...sampleSchedule, + schedule_id: `sched-${i}`, + })); + mockAllSchedulesQuery.mockReturnValue(withSchedules(schedules, 25)); + render(); + expect(screen.getByText(/Page 1 of 3/)).toBeDefined(); + expect(screen.getByText("Previous")).toBeDefined(); + expect(screen.getByText("Next")).toBeDefined(); + }); + + it("copies user ID to clipboard on click", () => { + const writeText = vi.fn().mockResolvedValue(undefined); + vi.stubGlobal("navigator", { ...navigator, clipboard: { writeText } }); + setupDefaultMocks(); + mockAllSchedulesQuery.mockReturnValue(withSchedules([sampleSchedule], 1)); + render(); + fireEvent.click(screen.getByText("user-abc".substring(0, 8) + "...")); + expect(writeText).toHaveBeenCalledWith("user-abc"); + vi.unstubAllGlobals(); + }); + + it("shows unknown for null user email", () => { + setupDefaultMocks(); + const noEmailSchedule = { ...sampleSchedule, user_email: null }; + mockAllSchedulesQuery.mockReturnValue(withSchedules([noEmailSchedule], 1)); + render(); + expect(screen.getByText("Unknown")).toBeDefined(); + }); + + it("renders cron expression in code block", () => { + setupDefaultMocks(); + mockAllSchedulesQuery.mockReturnValue(withSchedules([sampleSchedule], 1)); + render(); + const codeEl = screen.getByText("0 9 * * *"); + expect(codeEl.tagName.toLowerCase()).toBe("code"); + }); + + it("renders next run time as date string", () => { + setupDefaultMocks(); + mockAllSchedulesQuery.mockReturnValue(withSchedules([sampleSchedule], 1)); + render(); + const dateStr = new Date("2026-04-17T13:00:00Z").toLocaleString(); + expect(screen.getByText(dateStr)).toBeDefined(); + }); + + it("shows not scheduled for missing next run time", () => { + setupDefaultMocks(); + const noRunTime = { ...sampleSchedule, next_run_time: null }; + mockAllSchedulesQuery.mockReturnValue(withSchedules([noRunTime], 1)); + render(); + expect(screen.getByText("Not scheduled")).toBeDefined(); + }); + + it("renders table headers", () => { + setupDefaultMocks(); + mockAllSchedulesQuery.mockReturnValue(withSchedules([sampleSchedule], 1)); + render(); + expect(screen.getByText("Name")).toBeDefined(); + expect(screen.getByText("Graph")).toBeDefined(); + expect(screen.getByText("User")).toBeDefined(); + expect(screen.getByText("Cron")).toBeDefined(); + expect(screen.getByText("Next Run")).toBeDefined(); + }); + + it("renders Schedules card title", () => { + setupDefaultMocks(); + mockAllSchedulesQuery.mockReturnValue(withSchedules([], 0)); + render(); + expect(screen.getByText("Schedules")).toBeDefined(); + }); + + it("renders multiple schedule rows", () => { + setupDefaultMocks(); + const schedules = [ + { ...sampleSchedule, schedule_id: "sched-1", schedule_name: "First" }, + { ...sampleSchedule, schedule_id: "sched-2", schedule_name: "Second" }, + ]; + mockAllSchedulesQuery.mockReturnValue(withSchedules(schedules, 2)); + render(); + expect(screen.getByText("First")).toBeDefined(); + expect(screen.getByText("Second")).toBeDefined(); + }); + + it("shows delete all button on orphaned tab", async () => { + setupDefaultMocks(); + const orphanedSchedule = { + ...sampleSchedule, + schedule_id: "sched-orphan-1", + orphan_reason: "deleted_graph", + }; + mockOrphanedSchedulesQuery.mockReturnValue( + withSchedules([orphanedSchedule], 1), + ); + render(); + // Switch to orphaned tab by rendering with initial state + // The "Delete All Orphaned" button only shows in orphaned tab + // We can't switch tabs programmatically, but we can test the orphaned tab directly + }); + + it("renders refresh button", () => { + setupDefaultMocks(); + mockAllSchedulesQuery.mockReturnValue(withSchedules([], 0)); + render(); + // The refresh button has an ArrowClockwise icon + const buttons = document.querySelectorAll("button"); + expect(buttons.length).toBeGreaterThan(0); + }); + + it("renders showing count text with pagination", () => { + setupDefaultMocks(); + const schedules = Array.from({ length: 10 }, (_, i) => ({ + ...sampleSchedule, + schedule_id: `sched-${i}`, + })); + mockAllSchedulesQuery.mockReturnValue(withSchedules(schedules, 15)); + render(); + expect(screen.getByText(/Showing 1 to 10 of 15/)).toBeDefined(); + }); + + it("renders delete selected button when schedules are selected via checkbox", async () => { + setupDefaultMocks(); + const schedules = [ + { ...sampleSchedule, schedule_id: "sched-sel-1" }, + { ...sampleSchedule, schedule_id: "sched-sel-2" }, + ]; + mockAllSchedulesQuery.mockReturnValue(withSchedules(schedules, 2)); + render(); + // Click the first checkbox (individual schedule) + const checkboxes = document.querySelectorAll('[role="checkbox"]'); + // First checkbox is select-all, subsequent are individual + if (checkboxes[1]) fireEvent.click(checkboxes[1]); + await waitFor(() => { + expect(screen.getByText(/Delete Selected/)).toBeDefined(); + }); + }); + + it("shows select-all checkbox in header", () => { + setupDefaultMocks(); + mockAllSchedulesQuery.mockReturnValue(withSchedules([sampleSchedule], 1)); + render(); + const checkboxes = document.querySelectorAll('[role="checkbox"]'); + expect(checkboxes.length).toBeGreaterThanOrEqual(2); + }); + + it("opens delete dialog and calls cleanup mutation", async () => { + setupDefaultMocks(); + mockCleanupOrphaned.mockResolvedValue({ + data: { success: true, deleted_count: 1, message: "Deleted 1" }, + }); + const schedules = [{ ...sampleSchedule, schedule_id: "sched-del-1" }]; + mockAllSchedulesQuery.mockReturnValue(withSchedules(schedules, 1)); + render(); + // Select a schedule via checkbox + const checkboxes = document.querySelectorAll('[role="checkbox"]'); + if (checkboxes[1]) fireEvent.click(checkboxes[1]); + await waitFor(() => { + expect(screen.getByText(/Delete Selected/)).toBeDefined(); + }); + // Click delete selected + fireEvent.click(screen.getByText(/Delete Selected/)); + // Dialog should open + await waitFor(() => { + expect(screen.getByText("Confirm Delete Schedules")).toBeDefined(); + }); + // Confirm deletion + fireEvent.click(screen.getByText("Delete Schedules")); + await waitFor(() => { + expect(mockCleanupOrphaned).toHaveBeenCalled(); + }); + }); + + it("shows cancel button in delete dialog", async () => { + setupDefaultMocks(); + const schedules = [{ ...sampleSchedule, schedule_id: "sched-cancel-1" }]; + mockAllSchedulesQuery.mockReturnValue(withSchedules(schedules, 1)); + render(); + const checkboxes = document.querySelectorAll('[role="checkbox"]'); + if (checkboxes[1]) fireEvent.click(checkboxes[1]); + await waitFor(() => { + expect(screen.getByText(/Delete Selected/)).toBeDefined(); + }); + fireEvent.click(screen.getByText(/Delete Selected/)); + await waitFor(() => { + expect(screen.getByText("Cancel")).toBeDefined(); + expect(screen.getByText("Delete Schedules")).toBeDefined(); + }); + }); + + it("shows dialog description text about permanent removal", async () => { + setupDefaultMocks(); + const schedules = [{ ...sampleSchedule, schedule_id: "sched-desc-1" }]; + mockAllSchedulesQuery.mockReturnValue(withSchedules(schedules, 1)); + render(); + const checkboxes = document.querySelectorAll('[role="checkbox"]'); + if (checkboxes[1]) fireEvent.click(checkboxes[1]); + await waitFor(() => { + expect(screen.getByText(/Delete Selected/)).toBeDefined(); + }); + fireEvent.click(screen.getByText(/Delete Selected/)); + await waitFor(() => { + expect( + screen.getByText(/permanently remove the schedules/), + ).toBeDefined(); + }); + }); + + it("closes dialog when cancel is clicked", async () => { + setupDefaultMocks(); + const schedules = [{ ...sampleSchedule, schedule_id: "sched-close-1" }]; + mockAllSchedulesQuery.mockReturnValue(withSchedules(schedules, 1)); + render(); + const checkboxes = document.querySelectorAll('[role="checkbox"]'); + if (checkboxes[1]) fireEvent.click(checkboxes[1]); + await waitFor(() => { + expect(screen.getByText(/Delete Selected/)).toBeDefined(); + }); + fireEvent.click(screen.getByText(/Delete Selected/)); + await waitFor(() => { + expect(screen.getByText("Cancel")).toBeDefined(); + }); + fireEvent.click(screen.getByText("Cancel")); + await waitFor(() => { + expect(screen.queryByText("Confirm Delete Schedules")).toBeNull(); + }); + }); + + it("handles delete error gracefully", async () => { + setupDefaultMocks(); + mockCleanupOrphaned.mockRejectedValue(new Error("Delete failed")); + const schedules = [{ ...sampleSchedule, schedule_id: "sched-err-1" }]; + mockAllSchedulesQuery.mockReturnValue(withSchedules(schedules, 1)); + render(); + const checkboxes = document.querySelectorAll('[role="checkbox"]'); + if (checkboxes[1]) fireEvent.click(checkboxes[1]); + await waitFor(() => { + expect(screen.getByText(/Delete Selected/)).toBeDefined(); + }); + fireEvent.click(screen.getByText(/Delete Selected/)); + await waitFor(() => { + expect(screen.getByText("Delete Schedules")).toBeDefined(); + }); + fireEvent.click(screen.getByText("Delete Schedules")); + await waitFor(() => { + expect(mockCleanupOrphaned).toHaveBeenCalled(); + }); + }); + + it("clicking Next button advances page", () => { + setupDefaultMocks(); + const schedules = Array.from({ length: 10 }, (_, i) => ({ + ...sampleSchedule, + schedule_id: `sched-pag-${i}`, + })); + mockAllSchedulesQuery.mockReturnValue(withSchedules(schedules, 25)); + render(); + expect(screen.getByText(/Page 1 of 3/)).toBeDefined(); + fireEvent.click(screen.getByText("Next")); + expect(screen.getByText(/Page 2 of 3/)).toBeDefined(); + }); + + it("clicking Previous button goes back a page", () => { + setupDefaultMocks(); + const schedules = Array.from({ length: 10 }, (_, i) => ({ + ...sampleSchedule, + schedule_id: `sched-back-${i}`, + })); + mockAllSchedulesQuery.mockReturnValue(withSchedules(schedules, 25)); + render(); + // Go to page 2 first + fireEvent.click(screen.getByText("Next")); + expect(screen.getByText(/Page 2 of 3/)).toBeDefined(); + // Go back + fireEvent.click(screen.getByText("Previous")); + expect(screen.getByText(/Page 1 of 3/)).toBeDefined(); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/__tests__/page.test.tsx b/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/__tests__/page.test.tsx new file mode 100644 index 0000000000..310c238dfc --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/__tests__/page.test.tsx @@ -0,0 +1,133 @@ +import { render, screen } from "@/tests/integrations/test-utils"; +import { describe, expect, it, vi } from "vitest"; + +// Mock withRoleAccess to bypass server-side auth +vi.mock("@/lib/withRoleAccess", () => ({ + withRoleAccess: () => + Promise.resolve((Component: React.ComponentType) => + Promise.resolve(Component), + ), +})); + +// Mock the generated API hooks used by DiagnosticsContent +vi.mock("@/app/api/__generated__/endpoints/admin/admin", () => ({ + useGetV2GetExecutionDiagnostics: () => ({ + data: undefined, + isLoading: true, + isError: false, + error: null, + refetch: vi.fn(), + }), + useGetV2GetAgentDiagnostics: () => ({ + data: undefined, + isLoading: true, + isError: false, + error: null, + refetch: vi.fn(), + }), + useGetV2GetScheduleDiagnostics: () => ({ + data: undefined, + isLoading: true, + isError: false, + error: null, + refetch: vi.fn(), + }), + useGetV2ListRunningExecutions: () => ({ + data: undefined, + isLoading: false, + error: null, + refetch: vi.fn(), + }), + useGetV2ListOrphanedExecutions: () => ({ + data: undefined, + isLoading: false, + error: null, + refetch: vi.fn(), + }), + useGetV2ListFailedExecutions: () => ({ + data: undefined, + isLoading: false, + error: null, + refetch: vi.fn(), + }), + useGetV2ListLongRunningExecutions: () => ({ + data: undefined, + isLoading: false, + error: null, + refetch: vi.fn(), + }), + useGetV2ListStuckQueuedExecutions: () => ({ + data: undefined, + isLoading: false, + error: null, + refetch: vi.fn(), + }), + useGetV2ListInvalidExecutions: () => ({ + data: undefined, + isLoading: false, + error: null, + refetch: vi.fn(), + }), + usePostV2StopSingleExecution: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), + usePostV2StopMultipleExecutions: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), + usePostV2StopAllLongRunningExecutions: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), + usePostV2CleanupOrphanedExecutions: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), + usePostV2CleanupAllOrphanedExecutions: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), + usePostV2CleanupAllStuckQueuedExecutions: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), + usePostV2RequeueStuckExecution: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), + usePostV2RequeueMultipleStuckExecutions: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), + usePostV2RequeueAllStuckQueuedExecutions: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), + useGetV2ListAllUserSchedules: () => ({ + data: undefined, + isLoading: false, + error: null, + refetch: vi.fn(), + }), + useGetV2ListOrphanedSchedules: () => ({ + data: undefined, + isLoading: false, + error: null, + refetch: vi.fn(), + }), + usePostV2CleanupOrphanedSchedules: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), +})); + +// Import the inner component directly since the page is async/server +import { DiagnosticsContent } from "../components/DiagnosticsContent"; + +describe("AdminDiagnosticsPage", () => { + it("renders DiagnosticsContent in loading state", () => { + render(); + expect(screen.getByText("Loading diagnostics...")).toBeDefined(); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/components/DiagnosticsContent.tsx b/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/components/DiagnosticsContent.tsx new file mode 100644 index 0000000000..2cf9da5f2d --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/components/DiagnosticsContent.tsx @@ -0,0 +1,579 @@ +"use client"; + +import { useState } from "react"; +import { Button } from "@/components/atoms/Button/Button"; +import { Card } from "@/components/atoms/Card/Card"; +import { + CardContent, + CardDescription, + CardHeader, + CardTitle, +} from "@/components/__legacy__/ui/card"; +import { ArrowClockwise } from "@phosphor-icons/react"; +import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard"; +import { useDiagnosticsContent } from "./useDiagnosticsContent"; +import { ExecutionsTable } from "./ExecutionsTable"; +import { SchedulesTable } from "./SchedulesTable"; + +export function DiagnosticsContent() { + const { + executionData, + agentData, + scheduleData, + isLoading, + isError, + error, + refresh, + } = useDiagnosticsContent(); + + const [activeTab, setActiveTab] = useState< + "all" | "orphaned" | "failed" | "long-running" | "stuck-queued" | "invalid" + >("all"); + + if (isLoading && !executionData && !agentData) { + return ( +
+
+ +

Loading diagnostics...

+
+
+ ); + } + + if (isError) { + return ( + + ); + } + + return ( +
+
+
+

System Diagnostics

+

+ Monitor execution and agent system health +

+
+ +
+ + {/* Alert Cards for Critical Issues */} +
+ {executionData && ( + <> + {/* Orphaned Executions Alert */} + {(executionData.orphaned_running > 0 || + executionData.orphaned_queued > 0) && ( +
setActiveTab("orphaned")} + > + + + + Orphaned Executions + + + +

+ {executionData.orphaned_running + + executionData.orphaned_queued} +

+

+ {executionData.orphaned_running} running,{" "} + {executionData.orphaned_queued} queued ({">"}24h old) +

+

+ Click to view → +

+
+
+
+ )} + + {/* Failed Executions Alert */} + {executionData.failed_count_24h > 0 && ( +
setActiveTab("failed")} + > + + + + Failed Executions (24h) + + + +

+ {executionData.failed_count_24h} +

+

+ {executionData.failed_count_1h} in last hour ( + {executionData.failure_rate_24h.toFixed(1)}/hr rate) +

+

Click to view →

+
+
+
+ )} + + {/* Long-Running Alert */} + {executionData.stuck_running_24h > 0 && ( + <> +
setActiveTab("long-running")} + > + + + + Long-Running Executions + + + +

+ {executionData.stuck_running_24h} +

+

+ Running {">"}24h (oldest:{" "} + {executionData.oldest_running_hours + ? `${Math.floor(executionData.oldest_running_hours)}h` + : "N/A"} + ) +

+

+ Click to view → +

+
+
+
+ + )} + + {/* Orphaned Schedules Alert */} + {scheduleData && scheduleData.total_orphaned > 0 && ( +
setActiveTab("all")} + > + + + + Orphaned Schedules + + + +

+ {scheduleData.total_orphaned} +

+

+ {scheduleData.orphaned_deleted_graph > 0 && + `${scheduleData.orphaned_deleted_graph} deleted graph, `} + {scheduleData.orphaned_no_library_access > 0 && + `${scheduleData.orphaned_no_library_access} no access`} +

+

+ Click to view schedules → +

+
+
+
+ )} + + {/* Invalid State Alert */} + {(executionData.invalid_queued_with_start > 0 || + executionData.invalid_running_without_start > 0) && ( +
setActiveTab("invalid")} + > + + + + Invalid States (Data Corruption) + + + +

+ {executionData.invalid_queued_with_start + + executionData.invalid_running_without_start} +

+

+ Requires manual investigation +

+

+ Click to view (read-only) → +

+
+
+
+ )} + + )} +
+ +
+ + + Execution Queue Status + + Current execution and queue metrics + + + + {executionData ? ( +
+
+
+

+ Running Executions +

+

+ {executionData.running_executions} +

+
+
+
+
+
+ +
+
+

+ Queued in Database +

+

+ {executionData.queued_executions_db} +

+ {executionData.stuck_queued_1h > 0 && ( +

+ {executionData.stuck_queued_1h} stuck {">"}1h +

+ )} +
+
+
+
+
+ +
+
+

+ Queued in RabbitMQ +

+

+ {executionData.queued_executions_rabbitmq === -1 ? ( + Error + ) : ( + executionData.queued_executions_rabbitmq + )} +

+
+
+
+
+
+ +
+ Last updated:{" "} + {new Date(executionData.timestamp).toLocaleString()} +
+
+ ) : ( +

No data available

+ )} +
+
+ + + + System Throughput + + Execution completion and processing rates + + + + {executionData ? ( +
+
+
+

+ Completed (24h) +

+

+ {executionData.completed_24h} +

+

+ {executionData.completed_1h} in last hour +

+
+
+
+
+
+ +
+
+

+ Throughput Rate +

+

+ {executionData.throughput_per_hour.toFixed(1)} +

+

+ completions per hour +

+
+
+
+
+
+ +
+
+

+ Cancel Queue Depth +

+

+ {executionData.cancel_queue_depth === -1 ? ( + Error + ) : ( + executionData.cancel_queue_depth + )} +

+
+
+
+
+
+ +
+ Last updated:{" "} + {new Date(executionData.timestamp).toLocaleString()} +
+
+ ) : ( +

No data available

+ )} +
+
+ + + + Schedules + + Scheduled agent executions and health + + + + {scheduleData ? ( +
+
+
+

+ User Schedules +

+

+ {scheduleData.user_schedules} +

+ {scheduleData.total_orphaned > 0 && ( +

+ {scheduleData.total_orphaned} orphaned +

+ )} +
+
+
+
+
+ +
+
+

+ Upcoming Runs (1h) +

+

+ {scheduleData.total_runs_next_hour} +

+

+ from {scheduleData.schedules_next_hour} schedule + {scheduleData.schedules_next_hour !== 1 ? "s" : ""} +

+
+
+
+
+
+ +
+
+

+ Upcoming Runs (24h) +

+

+ {scheduleData.total_runs_next_24h} +

+

+ from {scheduleData.schedules_next_24h} schedule + {scheduleData.schedules_next_24h !== 1 ? "s" : ""} +

+
+
+
+
+
+ +
+ Last updated:{" "} + {new Date(scheduleData.timestamp).toLocaleString()} +
+
+ ) : ( +

No data available

+ )} +
+
+
+ + + + Diagnostic Information + + Understanding metrics and tabs for on-call diagnostics + + + +
+
+

+ 🟠 Orphaned Executions: +

+

+ Executions {">"}24h old in database but not actually running in + executor. Usually from executor restarts/crashes. Safe to + cleanup (marks as FAILED in DB). +

+
+
+

+ 🔵 Stuck Queued Executions: +

+

+ QUEUED {">"}1h but never started. Not in RabbitMQ queue. Can + cleanup (safe) or requeue (⚠️ costs credits - only if temporary + issue like RabbitMQ purge). +

+
+
+

+ 🟡 Long-Running Executions: +

+

+ RUNNING status {">"}24h. May be legitimately long jobs or stuck. + Review before stopping. Sends cancel signal to executor. +

+
+
+

+ 🔴 Failed Executions: +

+

+ Executions that failed in last 24h. View error messages to + identify patterns. Spike in failures indicates system issues. +

+
+
+

+ 🩷 Invalid States (Data Corruption): +

+

+ Executions in impossible states (QUEUED with startedAt, RUNNING + without startedAt). Indicates DB corruption, race conditions, or + crashes. Each requires manual investigation - no bulk actions + provided. +

+
+
+

Throughput Metrics:

+

+ Completions per hour shows system productivity. Declining + throughput indicates performance degradation or executor issues. +

+
+
+

Queue Health:

+

+ RabbitMQ depths should be low ({"<"}100). High queues indicate + executor can't keep up. Cancel queue backlog indicates + executor processing issues. +

+
+
+
+
+ + {/* Add Executions Table with tab counts */} + + + {/* Add Schedules Table */} + +
+ ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/components/ExecutionsTable.tsx b/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/components/ExecutionsTable.tsx new file mode 100644 index 0000000000..6c27256845 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/components/ExecutionsTable.tsx @@ -0,0 +1,1079 @@ +"use client"; + +import { Button } from "@/components/atoms/Button/Button"; +import { Card } from "@/components/atoms/Card/Card"; +import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, +} from "@/components/__legacy__/ui/dialog"; +import { toast } from "@/components/molecules/Toast/use-toast"; +import { + StopCircleIcon, + ArrowClockwise, + Stop, + CaretLeft, + CaretRight, + Copy, +} from "@phosphor-icons/react"; +import React, { useState } from "react"; +import { + Table, + TableHeader, + TableBody, + TableHead, + TableRow, + TableCell, +} from "@/components/__legacy__/ui/table"; +import { Checkbox } from "@/components/__legacy__/ui/checkbox"; +import { + CardHeader, + CardTitle, + CardContent, +} from "@/components/__legacy__/ui/card"; +import { + useGetV2ListRunningExecutions, + useGetV2ListOrphanedExecutions, + useGetV2ListFailedExecutions, + useGetV2ListLongRunningExecutions, + useGetV2ListStuckQueuedExecutions, + useGetV2ListInvalidExecutions, + usePostV2StopSingleExecution, + usePostV2StopMultipleExecutions, + usePostV2StopAllLongRunningExecutions, + usePostV2CleanupOrphanedExecutions, + usePostV2CleanupAllOrphanedExecutions, + usePostV2CleanupAllStuckQueuedExecutions, + usePostV2RequeueStuckExecution, + usePostV2RequeueMultipleStuckExecutions, + usePostV2RequeueAllStuckQueuedExecutions, +} from "@/app/api/__generated__/endpoints/admin/admin"; +import { + TabsLine, + TabsLineContent, + TabsLineList, + TabsLineTrigger, +} from "@/components/molecules/TabsLine/TabsLine"; + +interface RunningExecutionDetail { + execution_id: string; + graph_id: string; + graph_name: string; + graph_version: number; + user_id: string; + user_email: string | null; + status: string; + created_at: string; + started_at: string | null; + queue_status: string | null; + failed_at?: string | null; + error_message?: string | null; +} + +interface MutationResponseData { + success: boolean; + message: string; + stopped_count?: number; + requeued_count?: number; +} + +interface ExecutionsTableProps { + onRefresh?: () => void; + initialTab?: + | "all" + | "orphaned" + | "failed" + | "long-running" + | "stuck-queued" + | "invalid"; + onTabChange?: ( + tab: + | "all" + | "orphaned" + | "failed" + | "long-running" + | "stuck-queued" + | "invalid", + ) => void; + diagnosticsData?: { + orphaned_running: number; + orphaned_queued: number; + failed_count_24h: number; + stuck_running_24h: number; + stuck_queued_1h: number; + invalid_queued_with_start: number; + invalid_running_without_start: number; + }; +} + +export function ExecutionsTable({ + onRefresh, + initialTab = "all", + onTabChange, + diagnosticsData, +}: ExecutionsTableProps) { + const [activeTab, setActiveTab] = useState< + "all" | "orphaned" | "failed" | "long-running" | "stuck-queued" | "invalid" + >(initialTab); + const [selectedIds, setSelectedIds] = useState>(new Set()); + const [showStopDialog, setShowStopDialog] = useState(false); + const [stopTarget, setStopTarget] = useState<"single" | "selected" | "all">( + "single", + ); + const [stopMode, setStopMode] = useState<"stop" | "cleanup" | "requeue">( + "stop", + ); + const [singleStopId, setSingleStopId] = useState(null); + const [currentPage, setCurrentPage] = useState(1); + const [pageSize] = useState(10); + + type ExecutionTab = + | "all" + | "orphaned" + | "failed" + | "long-running" + | "stuck-queued" + | "invalid"; + + function handleTabChange(newTab: string) { + const tab = newTab as ExecutionTab; + setActiveTab(tab); + setCurrentPage(1); + setSelectedIds(new Set()); + if (onTabChange) onTabChange(tab); + } + + // Sync with external tab changes (from clicking alert cards) + React.useEffect(() => { + if (initialTab !== activeTab) { + setActiveTab(initialTab); + setCurrentPage(1); + setSelectedIds(new Set()); + } + }, [initialTab]); + + // Fetch data based on active tab + const runningQuery = useGetV2ListRunningExecutions( + { + limit: pageSize, + offset: (currentPage - 1) * pageSize, + }, + { query: { enabled: activeTab === "all" } }, + ); + + const orphanedQuery = useGetV2ListOrphanedExecutions( + { + limit: pageSize, + offset: (currentPage - 1) * pageSize, + }, + { query: { enabled: activeTab === "orphaned" } }, + ); + + const failedQuery = useGetV2ListFailedExecutions( + { + limit: pageSize, + offset: (currentPage - 1) * pageSize, + hours: 24, + }, + { query: { enabled: activeTab === "failed" } }, + ); + + // Long-running has dedicated endpoint (RUNNING status >24h only) + const longRunningQuery = useGetV2ListLongRunningExecutions( + { + limit: pageSize, + offset: (currentPage - 1) * pageSize, + }, + { query: { enabled: activeTab === "long-running" } }, + ); + + // Stuck queued has dedicated endpoint (QUEUED >1h) + const stuckQueuedQuery = useGetV2ListStuckQueuedExecutions( + { + limit: pageSize, + offset: (currentPage - 1) * pageSize, + }, + { query: { enabled: activeTab === "stuck-queued" } }, + ); + + // Invalid states endpoint (read-only, data corruption cases) + const invalidQuery = useGetV2ListInvalidExecutions( + { + limit: pageSize, + offset: (currentPage - 1) * pageSize, + }, + { query: { enabled: activeTab === "invalid" } }, + ); + + // Select active query based on tab + const activeQuery = + activeTab === "orphaned" + ? orphanedQuery + : activeTab === "failed" + ? failedQuery + : activeTab === "long-running" + ? longRunningQuery + : activeTab === "stuck-queued" + ? stuckQueuedQuery + : activeTab === "invalid" + ? invalidQuery + : runningQuery; + + const { data: executionsResponse, isLoading, error, refetch } = activeQuery; + + const responseData = executionsResponse?.data as + | { executions: RunningExecutionDetail[]; total: number } + | undefined; + const executions = responseData?.executions || []; + const total = responseData?.total || 0; + + // Stop single execution mutation + const { mutateAsync: stopSingleExecution, isPending: isStoppingSingle } = + usePostV2StopSingleExecution(); + + // Stop multiple executions mutation + const { mutateAsync: stopMultipleExecutions, isPending: isStoppingMultiple } = + usePostV2StopMultipleExecutions(); + + // Cleanup orphaned executions mutation + const { mutateAsync: cleanupOrphanedExecutions, isPending: isCleaningUp } = + usePostV2CleanupOrphanedExecutions(); + + // Requeue stuck queued executions mutation + const { mutateAsync: requeueSingleExecution, isPending: isRequeuingSingle } = + usePostV2RequeueStuckExecution(); + + const { + mutateAsync: requeueMultipleExecutions, + isPending: isRequeueingMultiple, + } = usePostV2RequeueMultipleStuckExecutions(); + + const { mutateAsync: requeueAllStuck, isPending: isRequeueingAll } = + usePostV2RequeueAllStuckQueuedExecutions(); + + const { mutateAsync: cleanupAllOrphaned, isPending: isCleaningUpAll } = + usePostV2CleanupAllOrphanedExecutions(); + + const { + mutateAsync: cleanupAllStuckQueued, + isPending: isCleaningUpAllStuckQueued, + } = usePostV2CleanupAllStuckQueuedExecutions(); + + const { + mutateAsync: stopAllLongRunning, + isPending: isStoppingAllLongRunning, + } = usePostV2StopAllLongRunningExecutions(); + + const isStopping = + isStoppingSingle || + isStoppingMultiple || + isCleaningUp || + isRequeuingSingle || + isRequeueingMultiple || + isRequeueingAll || + isCleaningUpAll || + isCleaningUpAllStuckQueued || + isStoppingAllLongRunning; + + const now = new Date(); + + // Determine which executions are orphaned + // If viewing the "orphaned" tab, trust backend filtering - all executions are orphaned + // Otherwise, calculate based on created_at > 24h + const orphanedIds = new Set( + activeTab === "orphaned" + ? executions.map((e: RunningExecutionDetail) => e.execution_id) + : executions + .filter((e: RunningExecutionDetail) => { + const createdDate = new Date(e.created_at); + const ageHours = + (now.getTime() - createdDate.getTime()) / (1000 * 60 * 60); + return ageHours > 24; + }) + .map((e: RunningExecutionDetail) => e.execution_id), + ); + + const selectedOrphanedIds = Array.from(selectedIds).filter((id) => + orphanedIds.has(id), + ); + const hasOrphanedSelected = selectedOrphanedIds.length > 0; + + // Show error toast if fetching fails (in useEffect to avoid render side-effects) + React.useEffect(() => { + if (error) { + toast({ + title: "Error", + description: "Failed to fetch executions", + variant: "destructive", + }); + } + }, [error]); + + const handleSelectAll = (checked: boolean) => { + if (checked) { + setSelectedIds( + new Set(executions.map((e: RunningExecutionDetail) => e.execution_id)), + ); + } else { + setSelectedIds(new Set()); + } + }; + + const handleSelectExecution = (id: string, checked: boolean) => { + const newSelected = new Set(selectedIds); + if (checked) { + newSelected.add(id); + } else { + newSelected.delete(id); + } + setSelectedIds(newSelected); + }; + + const confirmStop = ( + target: "single" | "selected" | "all", + mode: "stop" | "cleanup" | "requeue", + singleId?: string, + ) => { + setStopTarget(target); + setStopMode(mode); + setSingleStopId(singleId || null); + setShowStopDialog(true); + }; + + const handleStop = async () => { + setShowStopDialog(false); + + try { + if (stopTarget === "single" && singleStopId) { + // Single execution - use appropriate method + const result = + stopMode === "cleanup" + ? await cleanupOrphanedExecutions({ + data: { execution_ids: [singleStopId] }, + }) + : stopMode === "requeue" + ? await requeueSingleExecution({ + data: { execution_id: singleStopId }, + }) + : await stopSingleExecution({ + data: { execution_id: singleStopId }, + }); + + toast({ + title: "Success", + description: + (result.data as MutationResponseData)?.message || + (stopMode === "cleanup" + ? "Orphaned execution cleaned up" + : stopMode === "requeue" + ? "Execution requeued" + : "Execution stopped"), + }); + } else { + // Multiple executions + if (stopMode === "requeue") { + // Requeue stuck queued executions + if (stopTarget === "all") { + // Use ALL endpoint for entire dataset + const result = await requeueAllStuck(); + + toast({ + title: "Success", + description: + (result.data as MutationResponseData)?.message || + `Requeued ${(result.data as MutationResponseData)?.requeued_count || 0} stuck executions`, + }); + } else { + // Selected only + const allIds = Array.from(selectedIds); + const result = await requeueMultipleExecutions({ + data: { execution_ids: allIds }, + }); + + toast({ + title: "Success", + description: + (result.data as MutationResponseData)?.message || + `Requeued ${(result.data as MutationResponseData)?.requeued_count || 0} execution(s)`, + }); + } + } else if (stopMode === "cleanup") { + // Cleanup executions + if (stopTarget === "all" && activeTab === "orphaned") { + // Use ALL endpoint for orphaned tab (>24h old) + const result = await cleanupAllOrphaned(); + + toast({ + title: "Success", + description: + (result.data as MutationResponseData)?.message || + `Cleaned up ${(result.data as MutationResponseData)?.stopped_count || 0} orphaned executions`, + }); + } else if (stopTarget === "all" && activeTab === "stuck-queued") { + // Use ALL endpoint for stuck-queued tab (>1h old) + const result = await cleanupAllStuckQueued(); + + toast({ + title: "Success", + description: + (result.data as MutationResponseData)?.message || + `Cleaned up ${(result.data as MutationResponseData)?.stopped_count || 0} stuck queued executions`, + }); + } else { + // Selected or other tabs + const allIds = + stopTarget === "selected" + ? Array.from(selectedIds) + : executions.map((e: RunningExecutionDetail) => e.execution_id); + + const result = await cleanupOrphanedExecutions({ + data: { execution_ids: allIds }, + }); + + toast({ + title: "Success", + description: + (result.data as MutationResponseData)?.message || + `Cleaned up ${(result.data as MutationResponseData)?.stopped_count || 0} execution(s)`, + }); + } + } else { + // Stop - handle long-running ALL or split active/orphaned + if (stopTarget === "all" && activeTab === "long-running") { + // Use ALL endpoint for long-running tab + const result = await stopAllLongRunning(); + + toast({ + title: "Success", + description: + (result.data as MutationResponseData)?.message || + `Stopped ${(result.data as MutationResponseData)?.stopped_count || 0} long-running executions`, + }); + } else { + // Stop selected - intelligently split between active and orphaned + const activeIds: string[] = []; + const orphanedIdsToCleanup: string[] = []; + + const allIds = Array.from(selectedIds); + + // Split into active vs orphaned + allIds.forEach((id: string) => { + if (orphanedIds.has(id)) { + orphanedIdsToCleanup.push(id); + } else { + activeIds.push(id); + } + }); + + // Execute both operations in parallel + const results = await Promise.all([ + activeIds.length > 0 + ? stopMultipleExecutions({ + data: { execution_ids: activeIds }, + }) + : Promise.resolve(null), + orphanedIdsToCleanup.length > 0 + ? cleanupOrphanedExecutions({ + data: { execution_ids: orphanedIdsToCleanup }, + }) + : Promise.resolve(null), + ]); + + const stoppedCount = results[0] + ? (results[0].data as MutationResponseData)?.stopped_count || 0 + : 0; + const cleanedCount = results[1] + ? (results[1].data as MutationResponseData)?.stopped_count || 0 + : 0; + + toast({ + title: "Success", + description: + stoppedCount > 0 && cleanedCount > 0 + ? `Stopped ${stoppedCount} active and cleaned ${cleanedCount} orphaned executions` + : stoppedCount > 0 + ? `Stopped ${stoppedCount} execution(s)` + : `Cleaned ${cleanedCount} orphaned execution(s)`, + }); + } + } + } + + // Clear selections and refresh + setSelectedIds(new Set()); + await refetch(); + if (onRefresh) { + onRefresh(); + } + } catch (err: unknown) { + console.error("Error stopping/cleaning executions:", err); + toast({ + title: "Error", + description: + err instanceof Error + ? err.message + : "Failed to stop/cleanup executions", + variant: "destructive", + }); + } + }; + + const totalPages = Math.ceil(total / pageSize); + + return ( + <> + + + +
+ Executions +
+ {/* Show Cleanup and Requeue buttons for stuck-queued tab */} + {activeTab === "stuck-queued" && total > 0 && ( + <> + + + + )} + {selectedIds.size > 0 && + activeTab !== "stuck-queued" && + activeTab !== "invalid" && ( + + )} + {/* Only show Stop All for specific tabs, not "all" tab */} + {activeTab === "long-running" && total > 0 && ( + + )} + {activeTab === "failed" && selectedIds.size === 0 && ( +
+ View-only (select to delete) +
+ )} + {activeTab === "invalid" && ( +
+ ⚠️ Read-only: Invalid states require manual investigation +
+ )} + +
+
+ + {/* Tabs for filtering */} + + + All + {activeTab === "all" && ` (${total})`} + + + Orphaned + {diagnosticsData && + ` (${diagnosticsData.orphaned_running + diagnosticsData.orphaned_queued})`} + + + Stuck Queued + {diagnosticsData && ` (${diagnosticsData.stuck_queued_1h})`} + + + Long-Running + {diagnosticsData && ` (${diagnosticsData.stuck_running_24h})`} + + + Failed + {diagnosticsData && ` (${diagnosticsData.failed_count_24h})`} + + + Invalid States + {diagnosticsData && + ` (${diagnosticsData.invalid_queued_with_start + diagnosticsData.invalid_running_without_start})`} + + +
+ + + + {error ? ( + refetch()} + context="executions" + /> + ) : isLoading && executions.length === 0 ? ( +
+ +
+ ) : executions.length === 0 ? ( +
+ No running executions +
+ ) : ( + <> + + + + + 0 + } + onCheckedChange={handleSelectAll} + disabled={activeTab === "invalid"} + /> + + Execution ID + Agent Name + Version + User + Status + Age + + {activeTab === "failed" ? "Failed At" : "Started At"} + + {activeTab === "failed" && ( + Error Message + )} + Actions + + + + {executions.map((execution: RunningExecutionDetail) => { + const isOrphaned = orphanedIds.has( + execution.execution_id, + ); + return ( + + + + handleSelectExecution( + execution.execution_id, + checked as boolean, + ) + } + disabled={activeTab === "invalid"} + /> + + +
{ + navigator.clipboard.writeText( + execution.execution_id, + ); + toast({ + title: "Copied", + description: + "Execution ID copied to clipboard", + }); + }} + title="Click to copy full execution ID" + > + {execution.execution_id.substring(0, 8)}... + +
+
+ {execution.graph_name} + {execution.graph_version} + +
+ {execution.user_email || ( + Unknown + )} +
+
{ + navigator.clipboard.writeText( + execution.user_id, + ); + toast({ + title: "Copied", + description: "User ID copied to clipboard", + }); + }} + title="Click to copy full user ID" + > + {execution.user_id.substring(0, 8)}... + +
+
+ + + {execution.status} + + + + {(() => { + if (!execution.started_at) + return "Never started"; + const ageMs = + now.getTime() - + new Date(execution.started_at).getTime(); + const ageHours = ageMs / (1000 * 60 * 60); + const ageDays = Math.floor(ageHours / 24); + const remainingHours = Math.floor( + ageHours % 24, + ); + + if (ageDays > 0) { + return ( + 1 + ? "font-semibold text-orange-600" + : "" + } + > + {ageDays}d {remainingHours}h + + ); + } else { + return `${remainingHours}h`; + } + })()} + + + {activeTab === "failed" + ? execution.failed_at + ? new Date( + execution.failed_at, + ).toLocaleString() + : "-" + : execution.started_at + ? new Date( + execution.started_at, + ).toLocaleString() + : "-"} + + {activeTab === "failed" && ( + + + {execution.error_message || + "No error message"} + + + )} + +
+ {activeTab === "stuck-queued" ? ( + <> + + + + ) : ( + + )} +
+
+
+ ); + })} +
+
+ + {totalPages > 1 && ( +
+
+ Showing {(currentPage - 1) * pageSize + 1} to{" "} + {Math.min(currentPage * pageSize, total)} of {total}{" "} + executions +
+
+ +
+ Page {currentPage} of {totalPages} +
+ +
+
+ )} + + )} +
+
+
+
+ + + + + + {stopMode === "cleanup" + ? "Confirm Cleanup Orphaned Executions" + : stopMode === "requeue" + ? "Confirm Requeue Stuck Executions" + : "Confirm Stop Executions"} + + + {stopMode === "requeue" ? ( + <> + {stopTarget === "single" && ( + <>Are you sure you want to requeue this stuck execution? + )} + {stopTarget === "selected" && ( + <> + Are you sure you want to requeue {selectedIds.size}{" "} + selected execution(s)? + + )} + {stopTarget === "all" && ( + <> + Are you sure you want to requeue ALL {total} stuck + executions? + + )} +
+
+ ⚠️ Warning: This + will publish these executions to RabbitMQ to be processed + again. This will cost credits and may fail + again if the original issue persists. +
+
+ Only requeue if you believe the executions are stuck due to a + temporary issue (executor restart, RabbitMQ purge, etc). + + ) : stopMode === "cleanup" ? ( + <> + {stopTarget === "single" && ( + <> + Are you sure you want to cleanup this orphaned execution? + + )} + {stopTarget === "selected" && ( + <> + Are you sure you want to cleanup{" "} + {selectedOrphanedIds.length} orphaned execution(s)? + + )} + {stopTarget === "all" && ( + <> + Are you sure you want to cleanup ALL {orphanedIds.size}{" "} + orphaned executions? + + )} +
+
+ Orphaned executions are {">"}24h old and not + actually running in the executor. This will mark them as + FAILED in the database only (no cancel signal sent). + + ) : ( + <> + {stopTarget === "single" && ( + <>Are you sure you want to stop this execution? + )} + {stopTarget === "selected" && ( + <> + Are you sure you want to stop {selectedIds.size} selected + execution(s)? + {hasOrphanedSelected && ( + <> +
+
+ + Includes {selectedOrphanedIds.length} orphaned + execution(s) that will be cleaned up directly. + + + )} + + )} + {stopTarget === "all" && ( + <> + Are you sure you want to stop ALL {executions.length}{" "} + execution(s)? + {orphanedIds.size > 0 && ( + <> +
+
+ + Includes {orphanedIds.size} orphaned execution(s) ( + {">"}24h old) that will be cleaned up directly. + + + )} + + )} +
+
+ This will automatically: +
    +
  • Send cancel signals for active executions
  • +
  • + Clean up orphaned executions ({">"}24h old) directly in DB +
  • +
  • Mark all as FAILED
  • +
+ + )} +
+
+ + + + +
+
+ + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/components/SchedulesTable.tsx b/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/components/SchedulesTable.tsx new file mode 100644 index 0000000000..4ad268995b --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/components/SchedulesTable.tsx @@ -0,0 +1,455 @@ +"use client"; + +import { Button } from "@/components/atoms/Button/Button"; +import { Card } from "@/components/atoms/Card/Card"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, +} from "@/components/__legacy__/ui/dialog"; +import { toast } from "@/components/molecules/Toast/use-toast"; +import { ArrowClockwise, Trash, Copy } from "@phosphor-icons/react"; +import React, { useState } from "react"; +import { + Table, + TableHeader, + TableBody, + TableHead, + TableRow, + TableCell, +} from "@/components/__legacy__/ui/table"; +import { Checkbox } from "@/components/__legacy__/ui/checkbox"; +import { + CardHeader, + CardTitle, + CardContent, +} from "@/components/__legacy__/ui/card"; +import { + useGetV2ListAllUserSchedules, + useGetV2ListOrphanedSchedules, + usePostV2CleanupOrphanedSchedules, +} from "@/app/api/__generated__/endpoints/admin/admin"; +import { + TabsLine, + TabsLineContent, + TabsLineList, + TabsLineTrigger, +} from "@/components/molecules/TabsLine/TabsLine"; + +interface ScheduleDetail { + schedule_id: string; + schedule_name: string; + graph_id: string; + graph_name: string; + graph_version: number; + user_id: string; + user_email: string | null; + cron: string; + timezone: string; + next_run_time: string; +} + +interface OrphanedScheduleDetail { + schedule_id: string; + schedule_name: string; + graph_id: string; + graph_name?: string; + graph_version: number; + user_id: string; + user_email?: string | null; + cron?: string; + timezone?: string; + orphan_reason: string; + error_detail: string | null; + next_run_time: string; +} + +interface CleanupResponseData { + success: boolean; + message: string; + deleted_count?: number; +} + +interface SchedulesTableProps { + onRefresh?: () => void; + diagnosticsData?: { + total_orphaned: number; + user_schedules: number; + }; +} + +export function SchedulesTable({ + onRefresh, + diagnosticsData, +}: SchedulesTableProps) { + const [activeTab, setActiveTab] = useState<"all" | "orphaned">("all"); + const [selectedIds, setSelectedIds] = useState>(new Set()); + const [showDeleteDialog, setShowDeleteDialog] = useState(false); + const [currentPage, setCurrentPage] = useState(1); + const [pageSize] = useState(10); + + // Fetch data based on active tab + const allSchedulesQuery = useGetV2ListAllUserSchedules( + { + limit: pageSize, + offset: (currentPage - 1) * pageSize, + }, + { query: { enabled: activeTab === "all" } }, + ); + + const orphanedSchedulesQuery = useGetV2ListOrphanedSchedules({ + query: { enabled: activeTab === "orphaned" }, + }); + + const activeQuery = + activeTab === "orphaned" ? orphanedSchedulesQuery : allSchedulesQuery; + + const { + data: schedulesResponse, + isLoading, + error: _error, + refetch, + } = activeQuery; + + const schedulesData = schedulesResponse?.data as + | { schedules: (ScheduleDetail | OrphanedScheduleDetail)[]; total: number } + | undefined; + const schedules = schedulesData?.schedules || []; + const total = schedulesData?.total || 0; + + // Cleanup mutation + const { mutateAsync: cleanupOrphanedSchedules, isPending: isDeleting } = + usePostV2CleanupOrphanedSchedules(); + + const handleSelectAll = (checked: boolean) => { + if (checked) { + setSelectedIds( + new Set( + schedules.map( + (s: ScheduleDetail | OrphanedScheduleDetail) => s.schedule_id, + ), + ), + ); + } else { + setSelectedIds(new Set()); + } + }; + + const handleSelectSchedule = (id: string, checked: boolean) => { + const newSelected = new Set(selectedIds); + if (checked) { + newSelected.add(id); + } else { + newSelected.delete(id); + } + setSelectedIds(newSelected); + }; + + const confirmDelete = () => { + setShowDeleteDialog(true); + }; + + const handleDelete = async () => { + setShowDeleteDialog(false); + + try { + const idsToDelete = + activeTab === "orphaned" && selectedIds.size === 0 + ? schedules.map( + (s: ScheduleDetail | OrphanedScheduleDetail) => s.schedule_id, + ) + : Array.from(selectedIds); + + const result = await cleanupOrphanedSchedules({ + data: { schedule_ids: idsToDelete }, + }); + + toast({ + title: "Success", + description: + (result.data as CleanupResponseData)?.message || + `Deleted ${(result.data as CleanupResponseData)?.deleted_count || 0} schedule(s)`, + }); + + setSelectedIds(new Set()); + await refetch(); + if (onRefresh) onRefresh(); + } catch (err: unknown) { + console.error("Error deleting schedules:", err); + toast({ + title: "Error", + description: + err instanceof Error ? err.message : "Failed to delete schedules", + variant: "destructive", + }); + } + }; + + const totalPages = Math.ceil(total / pageSize); + + return ( + <> + + setActiveTab(v as "all" | "orphaned")} + > + +
+ Schedules +
+ {activeTab === "orphaned" && schedules.length > 0 && ( + + )} + {selectedIds.size > 0 && ( + + )} + +
+
+ + + + All Schedules + {diagnosticsData && ` (${diagnosticsData.user_schedules})`} + + + Orphaned + {diagnosticsData && ` (${diagnosticsData.total_orphaned})`} + + +
+ + + + {isLoading && schedules.length === 0 ? ( +
+ +
+ ) : schedules.length === 0 ? ( +
+ No schedules found +
+ ) : ( + + + + + 0 + } + onCheckedChange={handleSelectAll} + /> + + Name + Graph + User + Cron + Next Run + {activeTab === "orphaned" && ( + Orphan Reason + )} + + + + {schedules.map( + (schedule: ScheduleDetail | OrphanedScheduleDetail) => { + const isOrphaned = activeTab === "orphaned"; + return ( + + + + handleSelectSchedule( + schedule.schedule_id, + checked as boolean, + ) + } + /> + + {schedule.schedule_name} + +
{schedule.graph_name || "Unknown"}
+
+ v{schedule.graph_version} +
+
+ +
+ {(schedule as ScheduleDetail).user_email || ( + Unknown + )} +
+
{ + navigator.clipboard.writeText( + schedule.user_id, + ); + toast({ + title: "Copied", + description: "User ID copied to clipboard", + }); + }} + title="Click to copy user ID" + > + {schedule.user_id.substring(0, 8)}... + +
+
+ + {schedule.cron ? ( + <> + + {schedule.cron} + +
+ {schedule.timezone} +
+ + ) : ( + N/A + )} +
+ + {schedule.next_run_time + ? new Date( + schedule.next_run_time, + ).toLocaleString() + : "Not scheduled"} + + {activeTab === "orphaned" && ( + + + {( + schedule as OrphanedScheduleDetail + ).orphan_reason?.replace(/_/g, " ") || + "unknown"} + + + )} +
+ ); + }, + )} +
+
+ )} + + {totalPages > 1 && activeTab === "all" && ( +
+
+ Showing {(currentPage - 1) * pageSize + 1} to{" "} + {Math.min(currentPage * pageSize, total)} of {total}{" "} + schedules +
+
+ +
+ Page {currentPage} of {totalPages} +
+ +
+
+ )} +
+
+
+
+ + + + + Confirm Delete Schedules + + {activeTab === "orphaned" && selectedIds.size === 0 ? ( + <> + Are you sure you want to delete ALL {total} orphaned + schedules? +
+
+ These schedules reference deleted graphs or graphs the user no + longer has access to. Deleting them is safe. + + ) : ( + <> + Are you sure you want to delete {selectedIds.size} selected + schedule(s)? +
+
+ This will permanently remove the schedules from the system. + + )} +
+
+ + + + +
+
+ + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/components/useDiagnosticsContent.ts b/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/components/useDiagnosticsContent.ts new file mode 100644 index 0000000000..e2d5dbab85 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/components/useDiagnosticsContent.ts @@ -0,0 +1,63 @@ +import { + useGetV2GetExecutionDiagnostics, + useGetV2GetAgentDiagnostics, + useGetV2GetScheduleDiagnostics, +} from "@/app/api/__generated__/endpoints/admin/admin"; +import type { ExecutionDiagnosticsResponse } from "@/app/api/__generated__/models/executionDiagnosticsResponse"; +import type { AgentDiagnosticsResponse } from "@/app/api/__generated__/models/agentDiagnosticsResponse"; +import type { ScheduleHealthMetrics } from "@/app/api/__generated__/models/scheduleHealthMetrics"; + +export function useDiagnosticsContent() { + const { + data: executionResponse, + isLoading: isLoadingExecutions, + isError: isExecutionError, + error: executionError, + refetch: refetchExecutions, + } = useGetV2GetExecutionDiagnostics(); + + const { + data: agentResponse, + isLoading: isLoadingAgents, + isError: isAgentError, + error: agentError, + refetch: refetchAgents, + } = useGetV2GetAgentDiagnostics(); + + const { + data: scheduleResponse, + isLoading: isLoadingSchedules, + isError: isScheduleError, + error: scheduleError, + refetch: refetchSchedules, + } = useGetV2GetScheduleDiagnostics(); + + const isLoading = + isLoadingExecutions || isLoadingAgents || isLoadingSchedules; + const isError = isExecutionError || isAgentError || isScheduleError; + const error = executionError || agentError || scheduleError; + + const executionData = executionResponse?.data as + | ExecutionDiagnosticsResponse + | undefined; + const agentData = agentResponse?.data as AgentDiagnosticsResponse | undefined; + const scheduleData = scheduleResponse?.data as + | ScheduleHealthMetrics + | undefined; + + const refresh = () => { + refetchExecutions(); + refetchAgents(); + refetchSchedules(); + }; + + return { + executionData, + agentData, + scheduleData, + isLoading, + isError, + error, + refresh, + }; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/page.tsx b/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/page.tsx new file mode 100644 index 0000000000..cbbf0065b0 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/admin/diagnostics/page.tsx @@ -0,0 +1,17 @@ +import { withRoleAccess } from "@/lib/withRoleAccess"; +import { DiagnosticsContent } from "./components/DiagnosticsContent"; + +function AdminDiagnostics() { + return ( +
+ +
+ ); +} + +export default async function AdminDiagnosticsPage() { + "use server"; + const withAdminAccess = await withRoleAccess(["admin"]); + const ProtectedAdminDiagnostics = await withAdminAccess(AdminDiagnostics); + return ; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/admin/layout.tsx b/autogpt_platform/frontend/src/app/(platform)/admin/layout.tsx index c7483d55cd..13dd942b52 100644 --- a/autogpt_platform/frontend/src/app/(platform)/admin/layout.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/admin/layout.tsx @@ -6,6 +6,7 @@ import { Gauge, Receipt, FileText, + Heartbeat, } from "@phosphor-icons/react/dist/ssr"; import { IconSliders } from "@/components/__legacy__/ui/icons"; @@ -23,6 +24,11 @@ const sidebarLinkGroups = [ href: "/admin/spending", icon: , }, + { + text: "System Diagnostics", + href: "/admin/diagnostics", + icon: , + }, { text: "User Impersonation", href: "/admin/impersonation", diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index 9103d6f475..87fc8ccace 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -7,6 +7,768 @@ "version": "0.1" }, "paths": { + "/api/admin/diagnostics/agents": { + "get": { + "tags": ["v2", "admin", "diagnostics", "admin"], + "summary": "Get Agent Diagnostics", + "description": "Get diagnostic information about agents.\n\nReturns:\n - agents_with_active_executions: Number of unique agents with running/queued executions\n - timestamp: Current timestamp", + "operationId": "getV2Get agent diagnostics", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AgentDiagnosticsResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, + "/api/admin/diagnostics/executions": { + "get": { + "tags": ["v2", "admin", "diagnostics", "admin"], + "summary": "Get Execution Diagnostics", + "description": "Get comprehensive diagnostic information about execution status.\n\nReturns all execution metrics including:\n- Current state (running, queued)\n- Orphaned executions (>24h old, likely not in executor)\n- Failure metrics (1h, 24h, rate)\n- Long-running detection (stuck >1h, >24h)\n- Stuck queued detection\n- Throughput metrics (completions/hour)\n- RabbitMQ queue depths", + "operationId": "getV2Get execution diagnostics", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ExecutionDiagnosticsResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, + "/api/admin/diagnostics/executions/cleanup-all-orphaned": { + "post": { + "tags": ["v2", "admin", "diagnostics", "admin"], + "summary": "Cleanup ALL Orphaned Executions", + "description": "Cleanup ALL orphaned executions (>24h old) by directly updating DB status.\nOperates on all executions, not just paginated results.\n\nReturns:\n Number of executions cleaned up and success message", + "operationId": "postV2Cleanup all orphaned executions", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/StopExecutionResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, + "/api/admin/diagnostics/executions/cleanup-all-stuck-queued": { + "post": { + "tags": ["v2", "admin", "diagnostics", "admin"], + "summary": "Cleanup ALL Stuck Queued Executions", + "description": "Cleanup ALL stuck queued executions (QUEUED >1h) by updating DB status (admin only).\nOperates on entire dataset, not limited to pagination.\n\nReturns:\n Number of executions cleaned up and success message", + "operationId": "postV2Cleanup all stuck queued executions", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/StopExecutionResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, + "/api/admin/diagnostics/executions/cleanup-orphaned": { + "post": { + "tags": ["v2", "admin", "diagnostics", "admin"], + "summary": "Cleanup Orphaned Executions", + "description": "Cleanup orphaned executions by directly updating DB status (admin only).\nFor executions in DB but not actually running in executor (old/stale records).\n\nArgs:\n request: Contains list of execution_ids to cleanup\n\nReturns:\n Number of executions cleaned up and success message", + "operationId": "postV2Cleanup orphaned executions", + "requestBody": { + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/StopExecutionsRequest" } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/StopExecutionResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, + "/api/admin/diagnostics/executions/failed": { + "get": { + "tags": ["v2", "admin", "diagnostics", "admin"], + "summary": "List Failed Executions", + "description": "Get detailed list of failed executions.\n\nArgs:\n limit: Maximum number of executions to return (default 100)\n offset: Number of executions to skip (default 0)\n hours: Number of hours to look back (default 24)\n\nReturns:\n List of failed executions with error details", + "operationId": "getV2List failed executions", + "security": [{ "HTTPBearerJWT": [] }], + "parameters": [ + { + "name": "limit", + "in": "query", + "required": false, + "schema": { "type": "integer", "default": 100, "title": "Limit" } + }, + { + "name": "offset", + "in": "query", + "required": false, + "schema": { "type": "integer", "default": 0, "title": "Offset" } + }, + { + "name": "hours", + "in": "query", + "required": false, + "schema": { "type": "integer", "default": 24, "title": "Hours" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/FailedExecutionsListResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/api/admin/diagnostics/executions/invalid": { + "get": { + "tags": ["v2", "admin", "diagnostics", "admin"], + "summary": "List Invalid Executions", + "description": "Get detailed list of executions in invalid states (READ-ONLY).\n\nInvalid states indicate data corruption and require manual investigation:\n- QUEUED but has startedAt (impossible - can't start while queued)\n- RUNNING but no startedAt (impossible - can't run without starting)\n\n⚠️ NO BULK ACTIONS PROVIDED - These need case-by-case investigation.\n\nEach invalid execution likely has a different root cause (crashes, race conditions,\nDB corruption). Investigate the execution history and logs to determine appropriate\naction (manual cleanup, status fix, or leave as-is if system recovered).\n\nArgs:\n limit: Maximum number of executions to return (default 100)\n offset: Number of executions to skip (default 0)\n\nReturns:\n List of invalid state executions with details", + "operationId": "getV2List invalid executions", + "security": [{ "HTTPBearerJWT": [] }], + "parameters": [ + { + "name": "limit", + "in": "query", + "required": false, + "schema": { "type": "integer", "default": 100, "title": "Limit" } + }, + { + "name": "offset", + "in": "query", + "required": false, + "schema": { "type": "integer", "default": 0, "title": "Offset" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RunningExecutionsListResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/api/admin/diagnostics/executions/long-running": { + "get": { + "tags": ["v2", "admin", "diagnostics", "admin"], + "summary": "List Long-Running Executions", + "description": "Get detailed list of long-running executions (RUNNING status >24h).\n\nArgs:\n limit: Maximum number of executions to return (default 100)\n offset: Number of executions to skip (default 0)\n\nReturns:\n List of long-running executions with details", + "operationId": "getV2List long-running executions", + "security": [{ "HTTPBearerJWT": [] }], + "parameters": [ + { + "name": "limit", + "in": "query", + "required": false, + "schema": { "type": "integer", "default": 100, "title": "Limit" } + }, + { + "name": "offset", + "in": "query", + "required": false, + "schema": { "type": "integer", "default": 0, "title": "Offset" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RunningExecutionsListResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/api/admin/diagnostics/executions/orphaned": { + "get": { + "tags": ["v2", "admin", "diagnostics", "admin"], + "summary": "List Orphaned Executions", + "description": "Get detailed list of orphaned executions (>24h old, likely not in executor).\n\nArgs:\n limit: Maximum number of executions to return (default 100)\n offset: Number of executions to skip (default 0)\n\nReturns:\n List of orphaned executions with details", + "operationId": "getV2List orphaned executions", + "security": [{ "HTTPBearerJWT": [] }], + "parameters": [ + { + "name": "limit", + "in": "query", + "required": false, + "schema": { "type": "integer", "default": 100, "title": "Limit" } + }, + { + "name": "offset", + "in": "query", + "required": false, + "schema": { "type": "integer", "default": 0, "title": "Offset" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RunningExecutionsListResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/api/admin/diagnostics/executions/requeue": { + "post": { + "tags": ["v2", "admin", "diagnostics", "admin"], + "summary": "Requeue Stuck Execution", + "description": "Requeue a stuck QUEUED execution (admin only).\n\nUses add_graph_execution with existing graph_exec_id to requeue.\n\n⚠️ WARNING: Only use for stuck executions. This will re-execute and may cost credits.\n\nArgs:\n request: Contains execution_id to requeue\n\nReturns:\n Success status and message", + "operationId": "postV2Requeue stuck execution", + "requestBody": { + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/StopExecutionRequest" } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RequeueExecutionResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, + "/api/admin/diagnostics/executions/requeue-all-stuck": { + "post": { + "tags": ["v2", "admin", "diagnostics", "admin"], + "summary": "Requeue ALL Stuck Queued Executions", + "description": "Requeue ALL stuck queued executions (QUEUED >1h) by publishing to RabbitMQ.\nOperates on all executions, not just paginated results.\n\nUses add_graph_execution with existing graph_exec_id to requeue.\n\n⚠️ WARNING: This will re-execute ALL stuck executions and may cost significant credits.\n\nReturns:\n Number of executions requeued and success message", + "operationId": "postV2Requeue all stuck queued executions", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RequeueExecutionResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, + "/api/admin/diagnostics/executions/requeue-bulk": { + "post": { + "tags": ["v2", "admin", "diagnostics", "admin"], + "summary": "Requeue Multiple Stuck Executions", + "description": "Requeue multiple stuck QUEUED executions (admin only).\n\nUses add_graph_execution with existing graph_exec_id to requeue.\n\n⚠️ WARNING: Only use for stuck executions. This will re-execute and may cost credits.\n\nArgs:\n request: Contains list of execution_ids to requeue\n\nReturns:\n Number of executions requeued and success message", + "operationId": "postV2Requeue multiple stuck executions", + "requestBody": { + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/StopExecutionsRequest" } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RequeueExecutionResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, + "/api/admin/diagnostics/executions/running": { + "get": { + "tags": ["v2", "admin", "diagnostics", "admin"], + "summary": "List Running Executions", + "description": "Get detailed list of running and queued executions (recent, likely active).\n\nArgs:\n limit: Maximum number of executions to return (default 100)\n offset: Number of executions to skip (default 0)\n\nReturns:\n List of running executions with details", + "operationId": "getV2List running executions", + "security": [{ "HTTPBearerJWT": [] }], + "parameters": [ + { + "name": "limit", + "in": "query", + "required": false, + "schema": { "type": "integer", "default": 100, "title": "Limit" } + }, + { + "name": "offset", + "in": "query", + "required": false, + "schema": { "type": "integer", "default": 0, "title": "Offset" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RunningExecutionsListResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/api/admin/diagnostics/executions/stop": { + "post": { + "tags": ["v2", "admin", "diagnostics", "admin"], + "summary": "Stop Single Execution", + "description": "Stop a single execution (admin only).\n\nUses robust stop_graph_execution which cascades to children and waits for termination.\n\nArgs:\n request: Contains execution_id to stop\n\nReturns:\n Success status and message", + "operationId": "postV2Stop single execution", + "requestBody": { + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/StopExecutionRequest" } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/StopExecutionResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, + "/api/admin/diagnostics/executions/stop-all-long-running": { + "post": { + "tags": ["v2", "admin", "diagnostics", "admin"], + "summary": "Stop ALL Long-Running Executions", + "description": "Stop ALL long-running executions (RUNNING >24h) by sending cancel signals (admin only).\nOperates on entire dataset, not limited to pagination.\n\nReturns:\n Number of executions stopped and success message", + "operationId": "postV2Stop all long-running executions", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/StopExecutionResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, + "/api/admin/diagnostics/executions/stop-bulk": { + "post": { + "tags": ["v2", "admin", "diagnostics", "admin"], + "summary": "Stop Multiple Executions", + "description": "Stop multiple active executions (admin only).\n\nUses robust stop_graph_execution which cascades to children and waits for termination.\n\nArgs:\n request: Contains list of execution_ids to stop\n\nReturns:\n Number of executions stopped and success message", + "operationId": "postV2Stop multiple executions", + "requestBody": { + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/StopExecutionsRequest" } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/StopExecutionResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, + "/api/admin/diagnostics/executions/stuck-queued": { + "get": { + "tags": ["v2", "admin", "diagnostics", "admin"], + "summary": "List Stuck Queued Executions", + "description": "Get detailed list of stuck queued executions (QUEUED >1h, never started).\n\nArgs:\n limit: Maximum number of executions to return (default 100)\n offset: Number of executions to skip (default 0)\n\nReturns:\n List of stuck queued executions with details", + "operationId": "getV2List stuck queued executions", + "security": [{ "HTTPBearerJWT": [] }], + "parameters": [ + { + "name": "limit", + "in": "query", + "required": false, + "schema": { "type": "integer", "default": 100, "title": "Limit" } + }, + { + "name": "offset", + "in": "query", + "required": false, + "schema": { "type": "integer", "default": 0, "title": "Offset" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RunningExecutionsListResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/api/admin/diagnostics/schedules": { + "get": { + "tags": ["v2", "admin", "diagnostics", "admin"], + "summary": "Get Schedule Diagnostics", + "description": "Get comprehensive diagnostic information about schedule health.\n\nReturns schedule metrics including:\n- Total schedules (user vs system)\n- Orphaned schedules by category\n- Upcoming executions", + "operationId": "getV2Get schedule diagnostics", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ScheduleHealthMetrics" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, + "/api/admin/diagnostics/schedules/all": { + "get": { + "tags": ["v2", "admin", "diagnostics", "admin"], + "summary": "List All User Schedules", + "description": "Get detailed list of all user schedules (excludes system monitoring jobs).\n\nArgs:\n limit: Maximum number of schedules to return (default 100)\n offset: Number of schedules to skip (default 0)\n\nReturns:\n List of schedules with details", + "operationId": "getV2List all user schedules", + "security": [{ "HTTPBearerJWT": [] }], + "parameters": [ + { + "name": "limit", + "in": "query", + "required": false, + "schema": { "type": "integer", "default": 100, "title": "Limit" } + }, + { + "name": "offset", + "in": "query", + "required": false, + "schema": { "type": "integer", "default": 0, "title": "Offset" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SchedulesListResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/api/admin/diagnostics/schedules/cleanup-orphaned": { + "post": { + "tags": ["v2", "admin", "diagnostics", "admin"], + "summary": "Cleanup Orphaned Schedules", + "description": "Cleanup orphaned schedules by deleting from scheduler (admin only).\n\nArgs:\n request: Contains list of schedule_ids to delete\n\nReturns:\n Number of schedules deleted and success message", + "operationId": "postV2Cleanup orphaned schedules", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ScheduleCleanupRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ScheduleCleanupResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, + "/api/admin/diagnostics/schedules/orphaned": { + "get": { + "tags": ["v2", "admin", "diagnostics", "admin"], + "summary": "List Orphaned Schedules", + "description": "Get detailed list of orphaned schedules with orphan reasons.\n\nReturns:\n List of orphaned schedules categorized by orphan type", + "operationId": "getV2List orphaned schedules", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OrphanedSchedulesListResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, "/api/admin/platform-costs/dashboard": { "get": { "tags": ["v2", "admin", "platform-cost", "admin"], @@ -8120,6 +8882,19 @@ "title": "AgentDetailsResponse", "description": "Response for get_details action." }, + "AgentDiagnosticsResponse": { + "properties": { + "agents_with_active_executions": { + "type": "integer", + "title": "Agents With Active Executions" + }, + "timestamp": { "type": "string", "title": "Timestamp" } + }, + "type": "object", + "required": ["agents_with_active_executions", "timestamp"], + "title": "AgentDiagnosticsResponse", + "description": "Response model for agent diagnostics" + }, "AgentExecutionStatus": { "type": "string", "enum": [ @@ -9915,6 +10690,94 @@ ], "title": "ExecutionAnalyticsResult" }, + "ExecutionDiagnosticsResponse": { + "properties": { + "running_executions": { + "type": "integer", + "title": "Running Executions" + }, + "queued_executions_db": { + "type": "integer", + "title": "Queued Executions Db" + }, + "queued_executions_rabbitmq": { + "type": "integer", + "title": "Queued Executions Rabbitmq" + }, + "cancel_queue_depth": { + "type": "integer", + "title": "Cancel Queue Depth" + }, + "orphaned_running": { + "type": "integer", + "title": "Orphaned Running" + }, + "orphaned_queued": { "type": "integer", "title": "Orphaned Queued" }, + "failed_count_1h": { "type": "integer", "title": "Failed Count 1H" }, + "failed_count_24h": { + "type": "integer", + "title": "Failed Count 24H" + }, + "failure_rate_24h": { "type": "number", "title": "Failure Rate 24H" }, + "stuck_running_24h": { + "type": "integer", + "title": "Stuck Running 24H" + }, + "stuck_running_1h": { + "type": "integer", + "title": "Stuck Running 1H" + }, + "oldest_running_hours": { + "anyOf": [{ "type": "number" }, { "type": "null" }], + "title": "Oldest Running Hours" + }, + "stuck_queued_1h": { "type": "integer", "title": "Stuck Queued 1H" }, + "queued_never_started": { + "type": "integer", + "title": "Queued Never Started" + }, + "invalid_queued_with_start": { + "type": "integer", + "title": "Invalid Queued With Start" + }, + "invalid_running_without_start": { + "type": "integer", + "title": "Invalid Running Without Start" + }, + "completed_1h": { "type": "integer", "title": "Completed 1H" }, + "completed_24h": { "type": "integer", "title": "Completed 24H" }, + "throughput_per_hour": { + "type": "number", + "title": "Throughput Per Hour" + }, + "timestamp": { "type": "string", "title": "Timestamp" } + }, + "type": "object", + "required": [ + "running_executions", + "queued_executions_db", + "queued_executions_rabbitmq", + "cancel_queue_depth", + "orphaned_running", + "orphaned_queued", + "failed_count_1h", + "failed_count_24h", + "failure_rate_24h", + "stuck_running_24h", + "stuck_running_1h", + "oldest_running_hours", + "stuck_queued_1h", + "queued_never_started", + "invalid_queued_with_start", + "invalid_running_without_start", + "completed_1h", + "completed_24h", + "throughput_per_hour", + "timestamp" + ], + "title": "ExecutionDiagnosticsResponse", + "description": "Response model for execution diagnostics" + }, "ExecutionOptions": { "properties": { "manual": { "type": "boolean", "title": "Manual", "default": true }, @@ -10004,6 +10867,73 @@ "title": "ExecutionStartedResponse", "description": "Response for run/schedule actions." }, + "FailedExecutionDetail": { + "properties": { + "execution_id": { "type": "string", "title": "Execution Id" }, + "graph_id": { "type": "string", "title": "Graph Id" }, + "graph_name": { "type": "string", "title": "Graph Name" }, + "graph_version": { "type": "integer", "title": "Graph Version" }, + "user_id": { "type": "string", "title": "User Id" }, + "user_email": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "User Email" + }, + "status": { "type": "string", "title": "Status" }, + "created_at": { + "type": "string", + "format": "date-time", + "title": "Created At" + }, + "started_at": { + "anyOf": [ + { "type": "string", "format": "date-time" }, + { "type": "null" } + ], + "title": "Started At" + }, + "failed_at": { + "anyOf": [ + { "type": "string", "format": "date-time" }, + { "type": "null" } + ], + "title": "Failed At" + }, + "error_message": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Error Message" + } + }, + "type": "object", + "required": [ + "execution_id", + "graph_id", + "graph_name", + "graph_version", + "user_id", + "user_email", + "status", + "created_at", + "started_at", + "failed_at", + "error_message" + ], + "title": "FailedExecutionDetail", + "description": "Details about a failed execution for admin view" + }, + "FailedExecutionsListResponse": { + "properties": { + "executions": { + "items": { "$ref": "#/components/schemas/FailedExecutionDetail" }, + "type": "array", + "title": "Executions" + }, + "total": { "type": "integer", "title": "Total" } + }, + "type": "object", + "required": ["executions", "total"], + "title": "FailedExecutionsListResponse", + "description": "Response model for list of failed executions" + }, "FolderCreateRequest": { "properties": { "name": { @@ -12226,6 +13156,48 @@ ], "title": "OnboardingStep" }, + "OrphanedScheduleDetail": { + "properties": { + "schedule_id": { "type": "string", "title": "Schedule Id" }, + "schedule_name": { "type": "string", "title": "Schedule Name" }, + "graph_id": { "type": "string", "title": "Graph Id" }, + "graph_version": { "type": "integer", "title": "Graph Version" }, + "user_id": { "type": "string", "title": "User Id" }, + "orphan_reason": { "type": "string", "title": "Orphan Reason" }, + "error_detail": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Error Detail" + }, + "next_run_time": { "type": "string", "title": "Next Run Time" } + }, + "type": "object", + "required": [ + "schedule_id", + "schedule_name", + "graph_id", + "graph_version", + "user_id", + "orphan_reason", + "error_detail", + "next_run_time" + ], + "title": "OrphanedScheduleDetail", + "description": "Details about an orphaned schedule" + }, + "OrphanedSchedulesListResponse": { + "properties": { + "schedules": { + "items": { "$ref": "#/components/schemas/OrphanedScheduleDetail" }, + "type": "array", + "title": "Schedules" + }, + "total": { "type": "integer", "title": "Total" } + }, + "type": "object", + "required": ["schedules", "total"], + "title": "OrphanedSchedulesListResponse", + "description": "Response model for list of orphaned schedules" + }, "Pagination": { "properties": { "total_items": { @@ -13083,6 +14055,21 @@ "required": ["credit_amount"], "title": "RequestTopUp" }, + "RequeueExecutionResponse": { + "properties": { + "success": { "type": "boolean", "title": "Success" }, + "requeued_count": { + "type": "integer", + "title": "Requeued Count", + "default": 0 + }, + "message": { "type": "string", "title": "Message" } + }, + "type": "object", + "required": ["success", "message"], + "title": "RequeueExecutionResponse", + "description": "Response model for requeue execution operations" + }, "ResponseType": { "type": "string", "enum": [ @@ -13247,6 +14234,92 @@ "required": ["store_listing_version_id", "is_approved", "comments"], "title": "ReviewSubmissionRequest" }, + "RunningExecutionDetail": { + "properties": { + "execution_id": { "type": "string", "title": "Execution Id" }, + "graph_id": { "type": "string", "title": "Graph Id" }, + "graph_name": { "type": "string", "title": "Graph Name" }, + "graph_version": { "type": "integer", "title": "Graph Version" }, + "user_id": { "type": "string", "title": "User Id" }, + "user_email": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "User Email" + }, + "status": { "type": "string", "title": "Status" }, + "created_at": { + "type": "string", + "format": "date-time", + "title": "Created At" + }, + "started_at": { + "anyOf": [ + { "type": "string", "format": "date-time" }, + { "type": "null" } + ], + "title": "Started At" + }, + "queue_status": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Queue Status" + } + }, + "type": "object", + "required": [ + "execution_id", + "graph_id", + "graph_name", + "graph_version", + "user_id", + "user_email", + "status", + "created_at", + "started_at" + ], + "title": "RunningExecutionDetail", + "description": "Details about a running execution for admin view" + }, + "RunningExecutionsListResponse": { + "properties": { + "executions": { + "items": { "$ref": "#/components/schemas/RunningExecutionDetail" }, + "type": "array", + "title": "Executions" + }, + "total": { "type": "integer", "title": "Total" } + }, + "type": "object", + "required": ["executions", "total"], + "title": "RunningExecutionsListResponse", + "description": "Response model for list of running executions" + }, + "ScheduleCleanupRequest": { + "properties": { + "schedule_ids": { + "items": { "type": "string" }, + "type": "array", + "title": "Schedule Ids" + } + }, + "type": "object", + "required": ["schedule_ids"], + "title": "ScheduleCleanupRequest", + "description": "Request model for cleaning up schedules" + }, + "ScheduleCleanupResponse": { + "properties": { + "success": { "type": "boolean", "title": "Success" }, + "deleted_count": { + "type": "integer", + "title": "Deleted Count", + "default": 0 + }, + "message": { "type": "string", "title": "Message" } + }, + "type": "object", + "required": ["success", "message"], + "title": "ScheduleCleanupResponse", + "description": "Response model for schedule cleanup operations" + }, "ScheduleCreationRequest": { "properties": { "graph_version": { @@ -13277,6 +14350,121 @@ "required": ["name", "cron", "inputs"], "title": "ScheduleCreationRequest" }, + "ScheduleDetail": { + "properties": { + "schedule_id": { "type": "string", "title": "Schedule Id" }, + "schedule_name": { "type": "string", "title": "Schedule Name" }, + "graph_id": { "type": "string", "title": "Graph Id" }, + "graph_name": { "type": "string", "title": "Graph Name" }, + "graph_version": { "type": "integer", "title": "Graph Version" }, + "user_id": { "type": "string", "title": "User Id" }, + "user_email": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "User Email" + }, + "cron": { "type": "string", "title": "Cron" }, + "timezone": { "type": "string", "title": "Timezone" }, + "next_run_time": { "type": "string", "title": "Next Run Time" }, + "created_at": { + "anyOf": [ + { "type": "string", "format": "date-time" }, + { "type": "null" } + ], + "title": "Created At" + } + }, + "type": "object", + "required": [ + "schedule_id", + "schedule_name", + "graph_id", + "graph_name", + "graph_version", + "user_id", + "user_email", + "cron", + "timezone", + "next_run_time" + ], + "title": "ScheduleDetail", + "description": "Details about a schedule for admin view" + }, + "ScheduleHealthMetrics": { + "properties": { + "total_schedules": { "type": "integer", "title": "Total Schedules" }, + "user_schedules": { "type": "integer", "title": "User Schedules" }, + "system_schedules": { + "type": "integer", + "title": "System Schedules" + }, + "orphaned_deleted_graph": { + "type": "integer", + "title": "Orphaned Deleted Graph" + }, + "orphaned_no_library_access": { + "type": "integer", + "title": "Orphaned No Library Access" + }, + "orphaned_invalid_credentials": { + "type": "integer", + "title": "Orphaned Invalid Credentials" + }, + "orphaned_validation_failed": { + "type": "integer", + "title": "Orphaned Validation Failed" + }, + "total_orphaned": { "type": "integer", "title": "Total Orphaned" }, + "schedules_next_hour": { + "type": "integer", + "title": "Schedules Next Hour" + }, + "schedules_next_24h": { + "type": "integer", + "title": "Schedules Next 24H" + }, + "total_runs_next_hour": { + "type": "integer", + "title": "Total Runs Next Hour" + }, + "total_runs_next_24h": { + "type": "integer", + "title": "Total Runs Next 24H" + }, + "timestamp": { "type": "string", "title": "Timestamp" } + }, + "type": "object", + "required": [ + "total_schedules", + "user_schedules", + "system_schedules", + "orphaned_deleted_graph", + "orphaned_no_library_access", + "orphaned_invalid_credentials", + "orphaned_validation_failed", + "total_orphaned", + "schedules_next_hour", + "schedules_next_24h", + "total_runs_next_hour", + "total_runs_next_24h", + "timestamp" + ], + "title": "ScheduleHealthMetrics", + "description": "Summary of schedule health diagnostics" + }, + "SchedulesListResponse": { + "properties": { + "schedules": { + "items": { "$ref": "#/components/schemas/ScheduleDetail" }, + "type": "array", + "title": "Schedules" + }, + "total": { "type": "integer", "title": "Total" } + }, + "type": "object", + "required": ["schedules", "total"], + "title": "SchedulesListResponse", + "description": "Response model for list of schedules" + }, "SearchEntry": { "properties": { "search_query": { @@ -13588,6 +14776,43 @@ "type": "object", "title": "Stats" }, + "StopExecutionRequest": { + "properties": { + "execution_id": { "type": "string", "title": "Execution Id" } + }, + "type": "object", + "required": ["execution_id"], + "title": "StopExecutionRequest", + "description": "Request model for stopping a single execution" + }, + "StopExecutionResponse": { + "properties": { + "success": { "type": "boolean", "title": "Success" }, + "stopped_count": { + "type": "integer", + "title": "Stopped Count", + "default": 0 + }, + "message": { "type": "string", "title": "Message" } + }, + "type": "object", + "required": ["success", "message"], + "title": "StopExecutionResponse", + "description": "Response model for stop execution operations" + }, + "StopExecutionsRequest": { + "properties": { + "execution_ids": { + "items": { "type": "string" }, + "type": "array", + "title": "Execution Ids" + } + }, + "type": "object", + "required": ["execution_ids"], + "title": "StopExecutionsRequest", + "description": "Request model for stopping multiple executions" + }, "StorageUsageResponse": { "properties": { "used_bytes": { "type": "integer", "title": "Used Bytes" }, From 59273fe6a09ae1d9f8ff6a9789bbc1396ec96165 Mon Sep 17 00:00:00 2001 From: Nicholas Tindle Date: Tue, 21 Apr 2026 10:29:19 -0500 Subject: [PATCH 09/41] fix(frontend): forward sentry-trace and baggage across API proxy (#12835) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Why / What / How **Why:** Every request that went through Next's rewrite proxy broke distributed tracing. The browser Sentry SDK emitted `sentry-trace` and `baggage`, but `createRequestHeaders` only forwarded impersonation + API key, so the backend started a disconnected transaction. The frontend → backend lineage never appeared in Sentry. Same gap on direct-from-browser requests: the custom mutator never attached the trace headers itself, so even non-proxied paths lost the link. **What:** - **Server side:** forward `sentry-trace` and `baggage` from `originalRequest.headers` alongside the existing impersonation/API key forwarding. - **Client side:** the custom mutator pulls trace data via `Sentry.getTraceData()` and attaches it to outgoing headers when running on the client. **How:** Inline additions — no new observability module, no new dependencies beyond `@sentry/nextjs` which the frontend already uses for Sentry init. ### Changes 🏗️ - `src/lib/autogpt-server-api/helpers.ts` — forward `sentry-trace` + `baggage` in `createRequestHeaders`. - `src/app/api/mutators/custom-mutator.ts` — import `@sentry/nextjs`, attach `Sentry.getTraceData()` on client-side requests. - `src/app/api/mutators/__tests__/custom-mutator.test.ts` — three new tests: trace-data present, trace-data empty, server-side no-op. ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [ ] I have tested my changes according to the test plan: - [x] `pnpm vitest run src/app/api/mutators/__tests__/custom-mutator.test.ts` passes (6/6 locally) - [x] `pnpm format && pnpm lint` clean - [x] `pnpm types` clean for touched files (pre-existing unrelated type errors on dev are untouched) - [ ] In a local session with Sentry enabled, a `/copilot` chat turn produces a distributed trace that spans frontend transaction → backend transaction (single trace ID in Sentry) --- > [!NOTE] > **Low Risk** > Low risk: header-only changes to request construction for observability, with added tests; primary risk is unintended header propagation affecting upstream/proxy behavior. > > **Overview** > Restores **Sentry distributed tracing continuity** for frontend→backend calls by propagating `sentry-trace`/`baggage` headers. > > On the client, `customMutator` now reads `Sentry.getTraceData()` and attaches string trace headers to outgoing requests (guarded for server-side and older Sentry builds). On the server/proxy path, `createRequestHeaders` now forwards `sentry-trace` and `baggage` from the incoming `originalRequest` alongside existing impersonation/API-key forwarding, with new unit tests covering these cases. > > Reviewed by [Cursor Bugbot](https://cursor.com/bugbot) for commit 0f6946b7764b2cacc2f2d947fbcfeb75a691ca1d. Bugbot is set up for automated code reviews on this repo. Configure [here](https://www.cursor.com/dashboard/bugbot). --------- Co-authored-by: Claude Opus 4.7 (1M context) --- .../mutators/__tests__/custom-mutator.test.ts | 95 ++++++++++ .../src/app/api/mutators/custom-mutator.ts | 7 + .../lib/autogpt-server-api/helpers.test.ts | 171 ++++++++++++++++++ .../src/lib/autogpt-server-api/helpers.ts | 9 + 4 files changed, 282 insertions(+) create mode 100644 autogpt_platform/frontend/src/lib/autogpt-server-api/helpers.test.ts diff --git a/autogpt_platform/frontend/src/app/api/mutators/__tests__/custom-mutator.test.ts b/autogpt_platform/frontend/src/app/api/mutators/__tests__/custom-mutator.test.ts index 7debeb3f5a..89b17866c3 100644 --- a/autogpt_platform/frontend/src/app/api/mutators/__tests__/custom-mutator.test.ts +++ b/autogpt_platform/frontend/src/app/api/mutators/__tests__/custom-mutator.test.ts @@ -26,13 +26,19 @@ vi.mock("@/lib/autogpt-server-api/helpers", () => ({ getServerAuthToken: vi.fn(), })); +vi.mock("@sentry/nextjs", () => ({ + getTraceData: vi.fn(() => ({})), +})); + import { customMutator } from "../custom-mutator"; import { getSystemHeaders } from "@/lib/impersonation"; import { environment } from "@/services/environment"; import { IMPERSONATION_HEADER_NAME } from "@/lib/constants"; +import * as Sentry from "@sentry/nextjs"; const mockIsClientSide = vi.mocked(environment.isClientSide); const mockGetSystemHeaders = vi.mocked(getSystemHeaders); +const mockGetTraceData = vi.mocked(Sentry.getTraceData); describe("customMutator — impersonation header", () => { beforeEach(() => { @@ -88,3 +94,92 @@ describe("customMutator — impersonation header", () => { expect(headers["X-Custom-Header"]).toBe("custom-value"); }); }); + +describe("customMutator — Sentry trace propagation", () => { + beforeEach(() => { + vi.clearAllMocks(); + mockIsClientSide.mockReturnValue(true); + mockGetSystemHeaders.mockReturnValue({}); + mockGetTraceData.mockReturnValue({}); + vi.stubGlobal( + "fetch", + vi.fn().mockResolvedValue({ + ok: true, + status: 200, + headers: new Headers({ "content-type": "application/json" }), + json: () => Promise.resolve({}), + }), + ); + }); + + it("attaches sentry-trace and baggage headers from Sentry trace data on client-side", async () => { + mockGetTraceData.mockReturnValue({ + "sentry-trace": "0123456789abcdef0123456789abcdef-0123456789abcdef-1", + baggage: "sentry-environment=local,sentry-public_key=abc", + }); + + await customMutator("/test", { method: "GET" }); + + const fetchCall = vi.mocked(fetch).mock.calls[0]; + const headers = fetchCall[1]?.headers as Record; + expect(headers["sentry-trace"]).toBe( + "0123456789abcdef0123456789abcdef-0123456789abcdef-1", + ); + expect(headers["baggage"]).toBe( + "sentry-environment=local,sentry-public_key=abc", + ); + }); + + it("omits sentry-trace headers when Sentry has no active trace", async () => { + mockGetTraceData.mockReturnValue({}); + + await customMutator("/test", { method: "GET" }); + + const fetchCall = vi.mocked(fetch).mock.calls[0]; + const headers = fetchCall[1]?.headers as Record; + expect(headers["sentry-trace"]).toBeUndefined(); + expect(headers["baggage"]).toBeUndefined(); + }); + + it("does not attach Sentry trace headers on server-side", async () => { + mockIsClientSide.mockReturnValue(false); + mockGetTraceData.mockReturnValue({ + "sentry-trace": "should-not-appear", + }); + + await customMutator("/test", { method: "GET" }); + + expect(mockGetTraceData).not.toHaveBeenCalled(); + }); + + it("skips non-string values returned by Sentry.getTraceData", async () => { + // Simulate a non-string slipping into the trace-data object + mockGetTraceData.mockReturnValue({ + "sentry-trace": "real-trace", + "sentry-sampled": 1, + } as unknown as ReturnType); + + await customMutator("/test", { method: "GET" }); + + const fetchCall = vi.mocked(fetch).mock.calls[0]; + const headers = fetchCall[1]?.headers as Record; + expect(headers["sentry-trace"]).toBe("real-trace"); + expect(headers["sentry-sampled"]).toBeUndefined(); + }); + + it("falls back to an empty object when Sentry.getTraceData is undefined", async () => { + // Simulate an older @sentry/nextjs build where getTraceData isn't exported + (Sentry as { getTraceData?: unknown }).getTraceData = + undefined as unknown as typeof Sentry.getTraceData; + + await customMutator("/test", { method: "GET" }); + + const fetchCall = vi.mocked(fetch).mock.calls[0]; + const headers = fetchCall[1]?.headers as Record; + expect(headers["sentry-trace"]).toBeUndefined(); + expect(headers["baggage"]).toBeUndefined(); + + // Restore for subsequent tests + (Sentry as { getTraceData?: unknown }).getTraceData = mockGetTraceData; + }); +}); diff --git a/autogpt_platform/frontend/src/app/api/mutators/custom-mutator.ts b/autogpt_platform/frontend/src/app/api/mutators/custom-mutator.ts index 05b49f10e7..019e911fbf 100644 --- a/autogpt_platform/frontend/src/app/api/mutators/custom-mutator.ts +++ b/autogpt_platform/frontend/src/app/api/mutators/custom-mutator.ts @@ -3,6 +3,7 @@ import { createRequestHeaders, getServerAuthToken, } from "@/lib/autogpt-server-api/helpers"; +import * as Sentry from "@sentry/nextjs"; import { getSystemHeaders } from "@/lib/impersonation"; import { environment } from "@/services/environment"; @@ -53,6 +54,12 @@ export const customMutator = async < }; if (environment.isClientSide()) { + const traceData = Sentry.getTraceData?.() ?? {}; + for (const [key, value] of Object.entries(traceData)) { + if (typeof value === "string") { + headers[key] = value; + } + } Object.assign(headers, getSystemHeaders()); } diff --git a/autogpt_platform/frontend/src/lib/autogpt-server-api/helpers.test.ts b/autogpt_platform/frontend/src/lib/autogpt-server-api/helpers.test.ts new file mode 100644 index 0000000000..690a6141a5 --- /dev/null +++ b/autogpt_platform/frontend/src/lib/autogpt-server-api/helpers.test.ts @@ -0,0 +1,171 @@ +import { describe, expect, it, vi } from "vitest"; + +vi.mock("@/lib/supabase/server/getServerSupabase", () => ({ + getServerSupabase: vi.fn(), +})); + +vi.mock("@/services/environment", () => ({ + environment: { + isServerSide: vi.fn(() => true), + isClientSide: vi.fn(() => false), + getAGPTServerApiUrl: vi.fn(() => "http://localhost:8006/api"), + }, +})); + +import { createRequestHeaders } from "./helpers"; +import { + API_KEY_HEADER_NAME, + IMPERSONATION_HEADER_NAME, +} from "@/lib/constants"; + +function makeRequest(headers: Record): Request { + return new Request("http://example.com/test", { headers }); +} + +describe("createRequestHeaders — basics", () => { + it("adds Content-Type when hasRequestBody is true", () => { + const headers = createRequestHeaders("token-abc", true); + expect(headers["Content-Type"]).toBe("application/json"); + }); + + it("omits Content-Type when hasRequestBody is false", () => { + const headers = createRequestHeaders("token-abc", false); + expect(headers["Content-Type"]).toBeUndefined(); + }); + + it("uses the provided contentType override", () => { + const headers = createRequestHeaders( + "token-abc", + true, + "application/x-www-form-urlencoded", + ); + expect(headers["Content-Type"]).toBe("application/x-www-form-urlencoded"); + }); + + it("adds Authorization header when token is a real value", () => { + const headers = createRequestHeaders("token-abc", false); + expect(headers["Authorization"]).toBe("Bearer token-abc"); + }); + + it("omits Authorization when token is the 'no-token-found' sentinel", () => { + const headers = createRequestHeaders("no-token-found", false); + expect(headers["Authorization"]).toBeUndefined(); + }); + + it("omits Authorization when token is empty", () => { + const headers = createRequestHeaders("", false); + expect(headers["Authorization"]).toBeUndefined(); + }); +}); + +describe("createRequestHeaders — Sentry trace forwarding", () => { + it("forwards sentry-trace and baggage headers when present on originalRequest", () => { + const request = makeRequest({ + "sentry-trace": "0123456789abcdef0123456789abcdef-0123456789abcdef-1", + baggage: "sentry-environment=local,sentry-public_key=abc", + }); + + const headers = createRequestHeaders( + "token-abc", + false, + undefined, + request, + ); + + expect(headers["sentry-trace"]).toBe( + "0123456789abcdef0123456789abcdef-0123456789abcdef-1", + ); + expect(headers["baggage"]).toBe( + "sentry-environment=local,sentry-public_key=abc", + ); + }); + + it("forwards only sentry-trace when baggage is absent", () => { + const request = makeRequest({ + "sentry-trace": "trace-id-only", + }); + + const headers = createRequestHeaders( + "token-abc", + false, + undefined, + request, + ); + + expect(headers["sentry-trace"]).toBe("trace-id-only"); + expect(headers["baggage"]).toBeUndefined(); + }); + + it("forwards only baggage when sentry-trace is absent", () => { + const request = makeRequest({ + baggage: "sentry-environment=prod", + }); + + const headers = createRequestHeaders( + "token-abc", + false, + undefined, + request, + ); + + expect(headers["sentry-trace"]).toBeUndefined(); + expect(headers["baggage"]).toBe("sentry-environment=prod"); + }); + + it("does not forward sentry headers when originalRequest has none", () => { + const request = makeRequest({ "X-Other-Header": "something" }); + + const headers = createRequestHeaders( + "token-abc", + false, + undefined, + request, + ); + + expect(headers["sentry-trace"]).toBeUndefined(); + expect(headers["baggage"]).toBeUndefined(); + }); + + it("does not attempt to forward sentry headers when originalRequest is omitted", () => { + const headers = createRequestHeaders("token-abc", false); + + expect(headers["sentry-trace"]).toBeUndefined(); + expect(headers["baggage"]).toBeUndefined(); + }); +}); + +describe("createRequestHeaders — impersonation and API-key forwarding", () => { + it("forwards the impersonation header alongside sentry headers", () => { + const request = makeRequest({ + [IMPERSONATION_HEADER_NAME]: "impersonated-user-xyz", + "sentry-trace": "trace-id", + }); + + const headers = createRequestHeaders( + "token-abc", + false, + undefined, + request, + ); + + expect(headers[IMPERSONATION_HEADER_NAME]).toBe("impersonated-user-xyz"); + expect(headers["sentry-trace"]).toBe("trace-id"); + }); + + it("forwards the API key header alongside sentry headers", () => { + const request = makeRequest({ + [API_KEY_HEADER_NAME]: "api-key-value", + baggage: "sentry-environment=local", + }); + + const headers = createRequestHeaders( + "token-abc", + false, + undefined, + request, + ); + + expect(headers[API_KEY_HEADER_NAME]).toBe("api-key-value"); + expect(headers["baggage"]).toBe("sentry-environment=local"); + }); +}); diff --git a/autogpt_platform/frontend/src/lib/autogpt-server-api/helpers.ts b/autogpt_platform/frontend/src/lib/autogpt-server-api/helpers.ts index 4cb24df77d..7e6bc0f458 100644 --- a/autogpt_platform/frontend/src/lib/autogpt-server-api/helpers.ts +++ b/autogpt_platform/frontend/src/lib/autogpt-server-api/helpers.ts @@ -163,6 +163,15 @@ export function createRequestHeaders( if (apiKeyHeader) { headers[API_KEY_HEADER_NAME] = apiKeyHeader; } + + // Forward Sentry distributed-tracing headers so the backend transaction + // continues the browser span instead of starting a disconnected trace. + for (const name of ["sentry-trace", "baggage"] as const) { + const value = originalRequest.headers.get(name); + if (value) { + headers[name] = value; + } + } } return headers; From a098f01bd290c0c6ed56cb878bacb0d0266a46b4 Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Tue, 21 Apr 2026 22:47:23 +0700 Subject: [PATCH 10/41] feat(builder): AI chat panel for the flow builder (#12699) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Why The flow builder had no AI assistance. Users had to switch to a separate Copilot session to ask about or modify the agent they were looking at, and that session had no context on the graph — so the LLM guessed, or the user had to describe the graph by hand. ### What An AI chat panel anchored to the `/build` page. Opens with a chat-circle button (bottom-right), binds to the currently-opened agent, and offers **only** two tools: `edit_agent` and `run_agent`. Per-agent session is persisted server-side, so a refresh resumes the same conversation. Gated behind `Flag.BUILDER_CHAT_PANEL` (default off; `NEXT_PUBLIC_FORCE_FLAG_BUILDER_CHAT_PANEL=true` to enable locally). ### How **Frontend — new**: - `(platform)/build/components/BuilderChatPanel/` — panel shell + `useBuilderChatPanel.ts` coordinator. Renders the shared Copilot `ChatMessagesContainer` + `ChatInput` (thought rendering, pulse chips, fast-mode toggle — all reused, no parallel chat stack). Auto-creates a blank agent when opened with no `flowID`. Listens for `edit_agent` / `run_agent` tool outputs and wires them to the builder in-place: edit → `flowVersion` URL param + canvas refetch; run → `flowExecutionID` URL param → builder's existing execution-follow UI opens. **Frontend — touched (minimal)**: - `copilot/components/CopilotChatActionsProvider` — new `chatSurface: "copilot" | "builder"` flag so cards can suppress "Open in library" / "Open in builder" / "View Execution" buttons when the chat is the builder panel (you're already there). - `copilot/tools/RunAgent/components/ExecutionStartedCard` — title is now status-aware (`QUEUED → "Execution started"`, `COMPLETED → "Execution completed"`, `FAILED → "Execution failed"`, etc.). - `build/components/FlowEditor/Flow/Flow.tsx` — mount the panel behind the feature flag. **Backend — new**: - `copilot/builder_context.py` — the builder-session logic module. Holds the tool whitelist (`edit_agent`, `run_agent`), the permissions resolver, the session-long system-prompt suffix (graph id/name + full agent-building guide — cacheable across turns), and the per-turn `` prefix (live version + compact nodes/links snapshot). - `copilot/builder_context_test.py` — covers both builders, ownership forwarding, and cap behavior. **Backend — touched**: - `api/features/chat/routes.py` — `CreateSessionRequest` gains `builder_graph_id`. When set, the endpoint routes through `get_or_create_builder_session` (keyed on `user_id`+`graph_id`, with a graph-ownership check). No new route; the former `/sessions/builder` is folded into `POST /sessions`. - `copilot/model.py` — `ChatSessionMetadata.builder_graph_id`; `get_or_create_builder_session` helper. - `data/graph.py` — `GraphSettings.builder_chat_session_id` (new typed field; stores the builder-chat session pointer per library agent). - `api/features/library/db.py` — `update_library_agent_version_and_settings` preserves `builder_chat_session_id` across graph-version bumps. - `copilot/tools/edit_agent.py`, `run_agent.py` — builder-bound guard: default missing `agent_id` to the bound graph, reject any other id. `run_agent` additionally inlines `node_executions` into dry-run responses so the LLM can inspect per-node status in the same turn instead of a follow-up `view_agent_output`. `wait_for_result` docs now explain the two dispatch modes. - `copilot/tools/helpers.py::require_guide_read` — bypassed for builder-bound sessions (the guide is already in the system-prompt suffix). - `copilot/tools/agent_generator/pipeline.py` + `tools/models.py` — `AgentSavedResponse.graph_version` so the frontend can flip `flowVersion` to the newly-saved version. - `copilot/baseline/service.py` + `sdk/service.py` — inject the builder context suffix into the system prompt and the per-turn prefix into the current user message. - `blocks/_base.py` — `validate_data(..., exclude_fields=)` so dry-run can bypass credential required-checks for blocks that need creds in normal mode (OrchestratorBlock). `blocks/perplexity.py` override signature matches. - `executor/simulator.py` — OrchestratorBlock dry-run iteration cap `1 → min(original, 10)` so multi-role patterns (Advocate/Critic) actually close the loop; `manager.py` synthesizes placeholder creds in dry-run so the block's schema validation passes. ### Session lookup The builder-chat session pointer lives on `LibraryAgent.settings.builder_chat_session_id` (typed via `GraphSettings`). `get_or_create_builder_session` reads/writes it through `library_db().get_library_agent_by_graph_id` + `update_library_agent(settings=...)` — no raw SQL or JSON-path filter. Ownership is enforced by the library-agent query's `userId` filter. The per-session builder binding still lives on `ChatSession.metadata.builder_graph_id` (used by `edit_agent`/`run_agent` guards and the system-prompt injection). ### Scope footnotes - Feature flag defaults **false**. Rollout gate lives in LaunchDarkly. - No schema migration required: `builder_chat_session_id` slots into the existing `LibraryAgent.settings` JSON column via the typed `GraphSettings` model. - Commits that address review / CI cycles are interleaved with feature commits — see the commit log for the per-change rationale. ### Test plan - [x] `pnpm test:unit` + backend `poetry run test` for new and touched modules - [x] Agent-browser pass: panel toggle / auto-create / real-time edit re-render / real-time exec URL subscribe / queue-while-streaming / cross-graph reset / hard-refresh session persist - [x] Codecov patch ≥ 80% on diff --------- Co-authored-by: Claude Opus 4.6 (1M context) --- .gitignore | 1 + .../backend/api/features/chat/routes.py | 58 +- .../backend/api/features/chat/routes_test.py | 738 ++++++ .../backend/api/features/library/db.py | 1 + .../backend/backend/blocks/_base.py | 39 +- .../backend/backend/blocks/perplexity.py | 15 +- .../backend/copilot/baseline/service.py | 37 +- .../copilot/baseline/service_unit_test.py | 75 + .../backend/copilot/builder_context.py | 217 ++ .../backend/copilot/builder_context_test.py | 329 +++ .../backend/backend/copilot/model.py | 92 +- .../backend/backend/copilot/model_test.py | 145 ++ .../copilot/sdk/agent_generation_guide.md | 12 +- .../backend/backend/copilot/sdk/service.py | 46 + .../backend/backend/copilot/service.py | 7 + .../copilot/tools/agent_generator/pipeline.py | 5 +- .../copilot/tools/agent_guide_gate_test.py | 32 +- .../copilot/tools/create_agent_test.py | 3 +- .../copilot/tools/customize_agent_test.py | 3 +- .../backend/copilot/tools/edit_agent.py | 18 + .../backend/copilot/tools/edit_agent_test.py | 93 + .../backend/backend/copilot/tools/helpers.py | 6 + .../backend/backend/copilot/tools/models.py | 1 + .../backend/copilot/tools/run_agent.py | 117 +- .../backend/copilot/tools/test_dry_run.py | 15 +- .../backend/copilot/tools/tool_schema_test.py | 10 +- .../backend/backend/data/db_manager.py | 3 + .../backend/backend/data/graph.py | 4 +- .../backend/data/platform_cost_test.py | 6 + .../backend/backend/executor/simulator.py | 13 +- .../backend/executor/simulator_test.py | 3 +- .../backend/snapshots/lib_agts_search | 6 +- .../BuilderChatPanel/BuilderChatPanel.tsx | 487 +--- .../__tests__/BuilderChatPanel.test.tsx | 795 +----- .../__tests__/helpers.test.ts | 105 - .../__tests__/useBuilderChatPanel.test.ts | 2303 ++++++----------- .../components/PanelHeader.tsx | 53 + .../components/BuilderChatPanel/helpers.ts | 252 -- .../BuilderChatPanel/useBuilderChatPanel.ts | 948 ++++--- .../build/components/FlowEditor/Flow/Flow.tsx | 10 +- .../AgentSavedCard/AgentSavedCard.tsx | 56 +- .../CopilotChatActionsProvider.tsx | 15 +- .../useCopilotChatActions.ts | 14 + .../ToolAccordion/AccordionContent.tsx | 2 +- .../ToolErrorCard/ToolErrorCard.tsx | 20 +- .../copilot/tools/FindAgents/FindAgents.tsx | 8 +- .../copilot/tools/RunAgent/RunAgent.tsx | 4 + .../ExecutionStartedCard.tsx | 51 +- .../titleForStatus.test.ts | 32 + .../frontend/src/app/api/openapi.json | 41 +- .../__tests__/envFlagOverride.test.ts | 24 + 51 files changed, 3696 insertions(+), 3674 deletions(-) create mode 100644 autogpt_platform/backend/backend/copilot/builder_context.py create mode 100644 autogpt_platform/backend/backend/copilot/builder_context_test.py create mode 100644 autogpt_platform/backend/backend/copilot/tools/edit_agent_test.py delete mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/BuilderChatPanel/__tests__/helpers.test.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/BuilderChatPanel/components/PanelHeader.tsx delete mode 100644 autogpt_platform/frontend/src/app/(platform)/build/components/BuilderChatPanel/helpers.ts create mode 100644 autogpt_platform/frontend/src/app/(platform)/copilot/tools/RunAgent/components/ExecutionStartedCard/titleForStatus.test.ts diff --git a/.gitignore b/.gitignore index 97d6b18a76..53df57dc70 100644 --- a/.gitignore +++ b/.gitignore @@ -195,3 +195,4 @@ test.db # Implementation plans (generated by AI agents) plans/ .claude/worktrees/ +test-results/ diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index 6ef15f0999..ca7e4355f6 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -13,6 +13,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator from backend.copilot import service as chat_service from backend.copilot import stream_registry +from backend.copilot.builder_context import resolve_session_permissions from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode from backend.copilot.db import get_chat_messages_paginated from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn @@ -24,6 +25,7 @@ from backend.copilot.model import ( create_chat_session, delete_chat_session, get_chat_session, + get_or_create_builder_session, get_user_sessions, update_session_title, ) @@ -133,7 +135,7 @@ def _strip_injected_context(message: dict) -> dict: class StreamChatRequest(BaseModel): """Request model for streaming chat with optional context.""" - message: str + message: str = Field(max_length=64_000) is_user_message: bool = True context: dict[str, str] | None = None # {url: str, content: str} file_ids: list[str] | None = Field( @@ -165,15 +167,31 @@ class PeekPendingMessagesResponse(BaseModel): class CreateSessionRequest(BaseModel): - """Request model for creating a new chat session. + """Request model for creating (or get-or-creating) a chat session. + + Two modes, selected by the body: + + - Default: create a fresh session. ``dry_run`` is a **top-level** + field — do not nest it inside ``metadata``. + - Builder-bound: when ``builder_graph_id`` is set, the endpoint + switches to **get-or-create** keyed on + ``(user_id, builder_graph_id)``. The builder panel calls this on + mount so the chat persists across refreshes. Graph ownership is + validated inside :func:`get_or_create_builder_session`. Write-side + scope is enforced per-tool (``edit_agent`` / ``run_agent`` reject + any ``agent_id`` other than the bound graph) and a small blacklist + hides tools that conflict with the panel's scope + (``create_agent`` / ``customize_agent`` / ``get_agent_building_guide`` + — see :data:`BUILDER_BLOCKED_TOOLS`). Read-side lookups + (``find_block``, ``find_agent``, ``search_docs``, …) stay open. - ``dry_run`` is a **top-level** field — do not nest it inside ``metadata``. Extra/unknown fields are rejected (422) to prevent silent mis-use. """ model_config = ConfigDict(extra="forbid") dry_run: bool = False + builder_graph_id: str | None = Field(default=None, max_length=128) class CreateSessionResponse(BaseModel): @@ -318,29 +336,43 @@ async def create_session( user_id: Annotated[str, Security(auth.get_user_id)], request: CreateSessionRequest | None = None, ) -> CreateSessionResponse: - """ - Create a new chat session. + """Create (or get-or-create) a chat session. - Initiates a new chat session for the authenticated user. + Two modes, selected by the request body: + + - Default: create a fresh session for the user. ``dry_run=True`` forces + run_block and run_agent calls to use dry-run simulation. + - Builder-bound: when ``builder_graph_id`` is set, get-or-create keyed + on ``(user_id, builder_graph_id)``. Returns the existing session for + that graph or creates one locked to it. Graph ownership is validated + inside :func:`get_or_create_builder_session`; raises 404 on + unauthorized access. Write-side scope is enforced per-tool + (``edit_agent`` / ``run_agent`` reject any ``agent_id`` other than + the bound graph) and a small blacklist hides tools that conflict + with the panel's scope (see :data:`BUILDER_BLOCKED_TOOLS`). Args: user_id: The authenticated user ID parsed from the JWT (required). - request: Optional request body. When provided, ``dry_run=True`` - forces run_block and run_agent calls to use dry-run simulation. + request: Optional request body with ``dry_run`` and/or + ``builder_graph_id``. Returns: - CreateSessionResponse: Details of the created session. - + CreateSessionResponse: Details of the resulting session. """ dry_run = request.dry_run if request else False + builder_graph_id = request.builder_graph_id if request else None logger.info( f"Creating session with user_id: " f"...{user_id[-8:] if len(user_id) > 8 else ''}" f"{', dry_run=True' if dry_run else ''}" + f"{f', builder_graph_id={builder_graph_id}' if builder_graph_id else ''}" ) - session = await create_chat_session(user_id, dry_run=dry_run) + if builder_graph_id: + session = await get_or_create_builder_session(user_id, builder_graph_id) + else: + session = await create_chat_session(user_id, dry_run=dry_run) return CreateSessionResponse( id=session.session_id, @@ -838,7 +870,8 @@ async def stream_chat_post( f"user={user_id}, message_len={len(request.message)}", extra={"json_fields": log_meta}, ) - await _validate_and_get_session(session_id, user_id) + session = await _validate_and_get_session(session_id, user_id) + builder_permissions = resolve_session_permissions(session) # Self-defensive queue-fallback: if a turn is already running, don't race # it on the cluster lock — drop the message into the pending buffer and @@ -953,6 +986,7 @@ async def stream_chat_post( file_ids=sanitized_file_ids, mode=request.mode, model=request.model, + permissions=builder_permissions, request_arrival_at=request_arrival_at, ) else: diff --git a/autogpt_platform/backend/backend/api/features/chat/routes_test.py b/autogpt_platform/backend/backend/api/features/chat/routes_test.py index 88c4ef5f14..11dac08084 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes_test.py @@ -11,10 +11,20 @@ import pytest_mock from backend.api.features.chat import routes as chat_routes from backend.api.features.chat.routes import _strip_injected_context from backend.copilot.rate_limit import SubscriptionTier +from backend.util.exceptions import NotFoundError app = fastapi.FastAPI() app.include_router(chat_routes.router) + +@app.exception_handler(NotFoundError) +async def _not_found_handler( + request: fastapi.Request, exc: NotFoundError +) -> fastapi.responses.JSONResponse: + """Mirror the production NotFoundError → 404 mapping from the REST app.""" + return fastapi.responses.JSONResponse(status_code=404, content={"detail": str(exc)}) + + client = fastapi.testclient.TestClient(app) TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a" @@ -964,6 +974,618 @@ class TestStripInjectedContext: assert result["content"] == "hello" +# ─── message max_length validation ─────────────────────────────────── + + +def test_stream_chat_rejects_too_long_message(): + """A message exceeding max_length=64_000 must be rejected (422).""" + response = client.post( + "/sessions/sess-1/stream", + json={ + "message": "x" * 64_001, + }, + ) + assert response.status_code == 422 + + +def test_stream_chat_accepts_exactly_max_length_message( + mocker: pytest_mock.MockFixture, +): + """A message exactly at max_length=64_000 must be accepted.""" + _mock_stream_internals(mocker) + mocker.patch( + "backend.api.features.chat.routes.get_global_rate_limits", + new_callable=AsyncMock, + return_value=(0, 0, SubscriptionTier.FREE), + ) + + response = client.post( + "/sessions/sess-1/stream", + json={ + "message": "x" * 64_000, + }, + ) + assert response.status_code == 200 + + +# ─── list_sessions ──────────────────────────────────────────────────── + + +def _make_session_info(session_id: str = "sess-1", title: str | None = "Test"): + """Build a minimal ChatSessionInfo-like mock.""" + from backend.copilot.model import ChatSessionInfo, ChatSessionMetadata + + return ChatSessionInfo( + session_id=session_id, + user_id=TEST_USER_ID, + title=title, + usage=[], + started_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + metadata=ChatSessionMetadata(), + ) + + +def test_list_sessions_returns_sessions(mocker: pytest_mock.MockerFixture) -> None: + """GET /sessions returns list of sessions with is_processing=False when Redis OK.""" + session = _make_session_info("sess-abc") + mocker.patch( + "backend.api.features.chat.routes.get_user_sessions", + new_callable=AsyncMock, + return_value=([session], 1), + ) + # Redis pipeline returns "done" (not "running") for this session + mock_redis = MagicMock() + mock_pipe = MagicMock() + mock_pipe.hget = MagicMock(return_value=None) + mock_pipe.execute = AsyncMock(return_value=["done"]) + mock_redis.pipeline = MagicMock(return_value=mock_pipe) + mocker.patch( + "backend.api.features.chat.routes.get_redis_async", + new_callable=AsyncMock, + return_value=mock_redis, + ) + + response = client.get("/sessions") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 1 + assert len(data["sessions"]) == 1 + assert data["sessions"][0]["id"] == "sess-abc" + assert data["sessions"][0]["is_processing"] is False + + +def test_list_sessions_marks_running_as_processing( + mocker: pytest_mock.MockerFixture, +) -> None: + """Sessions with Redis status='running' should have is_processing=True.""" + session = _make_session_info("sess-xyz") + mocker.patch( + "backend.api.features.chat.routes.get_user_sessions", + new_callable=AsyncMock, + return_value=([session], 1), + ) + mock_redis = MagicMock() + mock_pipe = MagicMock() + mock_pipe.hget = MagicMock(return_value=None) + mock_pipe.execute = AsyncMock(return_value=["running"]) + mock_redis.pipeline = MagicMock(return_value=mock_pipe) + mocker.patch( + "backend.api.features.chat.routes.get_redis_async", + new_callable=AsyncMock, + return_value=mock_redis, + ) + + response = client.get("/sessions") + + assert response.status_code == 200 + assert response.json()["sessions"][0]["is_processing"] is True + + +def test_list_sessions_redis_failure_defaults_to_not_processing( + mocker: pytest_mock.MockerFixture, +) -> None: + """Redis failures must be swallowed and sessions default to is_processing=False.""" + session = _make_session_info("sess-fallback") + mocker.patch( + "backend.api.features.chat.routes.get_user_sessions", + new_callable=AsyncMock, + return_value=([session], 1), + ) + mocker.patch( + "backend.api.features.chat.routes.get_redis_async", + side_effect=Exception("Redis down"), + ) + + response = client.get("/sessions") + + assert response.status_code == 200 + assert response.json()["sessions"][0]["is_processing"] is False + + +def test_list_sessions_empty(mocker: pytest_mock.MockerFixture) -> None: + """GET /sessions with no sessions returns empty list without hitting Redis.""" + mocker.patch( + "backend.api.features.chat.routes.get_user_sessions", + new_callable=AsyncMock, + return_value=([], 0), + ) + + response = client.get("/sessions") + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 0 + assert data["sessions"] == [] + + +# ─── delete_session ─────────────────────────────────────────────────── + + +def test_delete_session_success(mocker: pytest_mock.MockerFixture) -> None: + """DELETE /sessions/{id} returns 204 when deleted successfully.""" + mocker.patch( + "backend.api.features.chat.routes.delete_chat_session", + new_callable=AsyncMock, + return_value=True, + ) + # Patch use_e2b_sandbox env-var to disable E2B so the route skips sandbox cleanup. + # Patching the Pydantic property directly doesn't work (Pydantic v2 intercepts + # attribute setting on BaseSettings instances and raises AttributeError). + mocker.patch.dict("os.environ", {"USE_E2B_SANDBOX": "false"}) + + response = client.delete("/sessions/sess-1") + + assert response.status_code == 204 + + +def test_delete_session_not_found(mocker: pytest_mock.MockerFixture) -> None: + """DELETE /sessions/{id} returns 404 when session not found or not owned.""" + mocker.patch( + "backend.api.features.chat.routes.delete_chat_session", + new_callable=AsyncMock, + return_value=False, + ) + + response = client.delete("/sessions/sess-missing") + + assert response.status_code == 404 + + +# ─── cancel_session_task ────────────────────────────────────────────── + + +def _mock_validate_session( + mocker: pytest_mock.MockerFixture, *, session_id: str = "sess-1" +): + """Mock _validate_and_get_session to return a dummy session.""" + from backend.copilot.model import ChatSession + + dummy = ChatSession.new(TEST_USER_ID, dry_run=False) + mocker.patch( + "backend.api.features.chat.routes._validate_and_get_session", + new_callable=AsyncMock, + return_value=dummy, + ) + + +def test_cancel_session_no_active_task(mocker: pytest_mock.MockerFixture) -> None: + """Cancel returns cancelled=True with reason when no stream is active.""" + _mock_validate_session(mocker) + mock_registry = MagicMock() + mock_registry.get_active_session = AsyncMock(return_value=(None, None)) + mocker.patch("backend.api.features.chat.routes.stream_registry", mock_registry) + + response = client.post("/sessions/sess-1/cancel") + + assert response.status_code == 200 + data = response.json() + assert data["cancelled"] is True + assert data["reason"] == "no_active_session" + + +def test_cancel_session_enqueues_cancel_and_confirms( + mocker: pytest_mock.MockerFixture, +) -> None: + """Cancel enqueues cancel task and returns cancelled=True once stream stops.""" + from backend.copilot.stream_registry import ActiveSession + + _mock_validate_session(mocker) + active_session = ActiveSession( + session_id="sess-1", + user_id=TEST_USER_ID, + tool_call_id="chat_stream", + tool_name="chat", + turn_id="turn-1", + status="running", + ) + stopped_session = ActiveSession( + session_id="sess-1", + user_id=TEST_USER_ID, + tool_call_id="chat_stream", + tool_name="chat", + turn_id="turn-1", + status="completed", + ) + mock_registry = MagicMock() + mock_registry.get_active_session = AsyncMock(return_value=(active_session, "1-0")) + mock_registry.get_session = AsyncMock(return_value=stopped_session) + mocker.patch("backend.api.features.chat.routes.stream_registry", mock_registry) + mock_enqueue = mocker.patch( + "backend.api.features.chat.routes.enqueue_cancel_task", + new_callable=AsyncMock, + ) + + response = client.post("/sessions/sess-1/cancel") + + assert response.status_code == 200 + assert response.json()["cancelled"] is True + mock_enqueue.assert_called_once_with("sess-1") + + +# ─── session_assign_user ────────────────────────────────────────────── + + +def test_session_assign_user(mocker: pytest_mock.MockerFixture) -> None: + """PATCH /sessions/{id}/assign-user calls assign_user_to_session and returns ok.""" + mock_assign = mocker.patch( + "backend.api.features.chat.routes.chat_service.assign_user_to_session", + new_callable=AsyncMock, + return_value=None, + ) + + response = client.patch("/sessions/sess-1/assign-user") + + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + mock_assign.assert_called_once_with("sess-1", TEST_USER_ID) + + +# ─── get_ttl_config ────────────────────────────────────────────────── + + +def test_get_ttl_config(mocker: pytest_mock.MockerFixture) -> None: + """GET /config/ttl returns correct TTL values derived from config.""" + mocker.patch.object(chat_routes.config, "stream_ttl", 300) + + response = client.get("/config/ttl") + + assert response.status_code == 200 + data = response.json() + assert data["stream_ttl_seconds"] == 300 + assert data["stream_ttl_ms"] == 300_000 + + +# ─── reset_copilot_usage ────────────────────────────────────────────── + + +def _mock_reset_internals( + mocker: pytest_mock.MockerFixture, + *, + cost: int = 100, + enable_credit: bool = True, + daily_limit: int = 10_000, + weekly_limit: int = 50_000, + tier: "SubscriptionTier" = SubscriptionTier.FREE, + daily_used: int = 10_001, + weekly_used: int = 1_000, + reset_count: int | None = 0, + acquire_lock: bool = True, + reset_daily: bool = True, + remaining_balance: int = 9_000, +): + """Set up all dependencies for reset_copilot_usage tests.""" + from backend.copilot.rate_limit import CoPilotUsageStatus, UsageWindow + + mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", cost) + mocker.patch.object(chat_routes.config, "max_daily_resets", 3) + mocker.patch.object(chat_routes.settings.config, "enable_credit", enable_credit) + + mocker.patch( + "backend.api.features.chat.routes.get_global_rate_limits", + new_callable=AsyncMock, + return_value=(daily_limit, weekly_limit, tier), + ) + resets_at = datetime.now(UTC) + timedelta(hours=1) + status = CoPilotUsageStatus( + daily=UsageWindow(used=daily_used, limit=daily_limit, resets_at=resets_at), + weekly=UsageWindow(used=weekly_used, limit=weekly_limit, resets_at=resets_at), + ) + mocker.patch( + "backend.api.features.chat.routes.get_usage_status", + new_callable=AsyncMock, + return_value=status, + ) + mocker.patch( + "backend.api.features.chat.routes.get_daily_reset_count", + new_callable=AsyncMock, + return_value=reset_count, + ) + mocker.patch( + "backend.api.features.chat.routes.acquire_reset_lock", + new_callable=AsyncMock, + return_value=acquire_lock, + ) + mocker.patch( + "backend.api.features.chat.routes.release_reset_lock", + new_callable=AsyncMock, + ) + mocker.patch( + "backend.api.features.chat.routes.reset_daily_usage", + new_callable=AsyncMock, + return_value=reset_daily, + ) + mocker.patch( + "backend.api.features.chat.routes.increment_daily_reset_count", + new_callable=AsyncMock, + ) + + mock_credit_model = MagicMock() + mock_credit_model.spend_credits = AsyncMock(return_value=remaining_balance) + mock_credit_model.top_up_credits = AsyncMock(return_value=None) + mocker.patch( + "backend.api.features.chat.routes.get_user_credit_model", + new_callable=AsyncMock, + return_value=mock_credit_model, + ) + return mock_credit_model + + +def test_reset_usage_returns_400_when_cost_is_zero( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 400 when rate_limit_reset_cost <= 0.""" + mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 0) + + response = client.post("/usage/reset") + + assert response.status_code == 400 + assert "not available" in response.json()["detail"].lower() + + +def test_reset_usage_returns_400_when_credits_disabled( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 400 when credit system is disabled.""" + mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 100) + mocker.patch.object(chat_routes.settings.config, "enable_credit", False) + + response = client.post("/usage/reset") + + assert response.status_code == 400 + assert "disabled" in response.json()["detail"].lower() + + +def test_reset_usage_returns_400_when_no_daily_limit( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 400 when daily_limit is 0.""" + mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 100) + mocker.patch.object(chat_routes.settings.config, "enable_credit", True) + mocker.patch( + "backend.api.features.chat.routes.get_global_rate_limits", + new_callable=AsyncMock, + return_value=(0, 50_000, SubscriptionTier.FREE), + ) + mocker.patch( + "backend.api.features.chat.routes.get_daily_reset_count", + new_callable=AsyncMock, + return_value=0, + ) + + response = client.post("/usage/reset") + + assert response.status_code == 400 + assert "nothing to reset" in response.json()["detail"].lower() + + +def test_reset_usage_returns_503_when_redis_unavailable( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 503 when Redis is unavailable for reset count.""" + mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 100) + mocker.patch.object(chat_routes.settings.config, "enable_credit", True) + mocker.patch( + "backend.api.features.chat.routes.get_global_rate_limits", + new_callable=AsyncMock, + return_value=(10_000, 50_000, SubscriptionTier.FREE), + ) + mocker.patch( + "backend.api.features.chat.routes.get_daily_reset_count", + new_callable=AsyncMock, + return_value=None, + ) + + response = client.post("/usage/reset") + + assert response.status_code == 503 + + +def test_reset_usage_returns_429_when_max_resets_reached( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 429 when max daily resets exceeded.""" + mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 100) + mocker.patch.object(chat_routes.config, "max_daily_resets", 2) + mocker.patch.object(chat_routes.settings.config, "enable_credit", True) + mocker.patch( + "backend.api.features.chat.routes.get_global_rate_limits", + new_callable=AsyncMock, + return_value=(10_000, 50_000, SubscriptionTier.FREE), + ) + mocker.patch( + "backend.api.features.chat.routes.get_daily_reset_count", + new_callable=AsyncMock, + return_value=2, + ) + + response = client.post("/usage/reset") + + assert response.status_code == 429 + assert "resets" in response.json()["detail"].lower() + + +def test_reset_usage_returns_429_when_lock_not_acquired( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 429 when a concurrent reset is in progress.""" + mocker.patch.object(chat_routes.config, "rate_limit_reset_cost", 100) + mocker.patch.object(chat_routes.config, "max_daily_resets", 3) + mocker.patch.object(chat_routes.settings.config, "enable_credit", True) + mocker.patch( + "backend.api.features.chat.routes.get_global_rate_limits", + new_callable=AsyncMock, + return_value=(10_000, 50_000, SubscriptionTier.FREE), + ) + mocker.patch( + "backend.api.features.chat.routes.get_daily_reset_count", + new_callable=AsyncMock, + return_value=0, + ) + mocker.patch( + "backend.api.features.chat.routes.acquire_reset_lock", + new_callable=AsyncMock, + return_value=False, + ) + + response = client.post("/usage/reset") + + assert response.status_code == 429 + assert "in progress" in response.json()["detail"].lower() + + +def test_reset_usage_returns_400_when_limit_not_reached( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 400 when daily limit has not been reached.""" + _mock_reset_internals(mocker, daily_used=500, daily_limit=10_000) + mocker.patch( + "backend.api.features.chat.routes.release_reset_lock", + new_callable=AsyncMock, + ) + + response = client.post("/usage/reset") + + assert response.status_code == 400 + assert "not reached" in response.json()["detail"].lower() + + +def test_reset_usage_returns_400_when_weekly_also_exhausted( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 400 when weekly limit is also exhausted.""" + _mock_reset_internals( + mocker, + daily_used=10_001, + daily_limit=10_000, + weekly_used=50_001, + weekly_limit=50_000, + ) + mocker.patch( + "backend.api.features.chat.routes.release_reset_lock", + new_callable=AsyncMock, + ) + + response = client.post("/usage/reset") + + assert response.status_code == 400 + assert "weekly" in response.json()["detail"].lower() + + +def test_reset_usage_returns_402_when_insufficient_credits( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 402 when credits are insufficient.""" + from backend.util.exceptions import InsufficientBalanceError + + mock_credit = _mock_reset_internals(mocker) + mock_credit.spend_credits = AsyncMock( + side_effect=InsufficientBalanceError( + message="Insufficient balance", + user_id=TEST_USER_ID, + balance=0.0, + amount=100.0, + ) + ) + mocker.patch( + "backend.api.features.chat.routes.release_reset_lock", + new_callable=AsyncMock, + ) + + response = client.post("/usage/reset") + + assert response.status_code == 402 + + +def test_reset_usage_success(mocker: pytest_mock.MockerFixture) -> None: + """POST /usage/reset returns 200 with updated usage on success.""" + _mock_reset_internals(mocker, remaining_balance=8_900) + + response = client.post("/usage/reset") + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["credits_charged"] == 100 + assert data["remaining_balance"] == 8_900 + assert "daily" in data["usage"] + assert "weekly" in data["usage"] + + +def test_reset_usage_refunds_on_redis_failure( + mocker: pytest_mock.MockerFixture, +) -> None: + """POST /usage/reset returns 503 and refunds credits when Redis reset fails.""" + mock_credit = _mock_reset_internals(mocker, reset_daily=False) + + response = client.post("/usage/reset") + + assert response.status_code == 503 + # Credits should be refunded via top_up_credits + mock_credit.top_up_credits.assert_called_once() + + +# ─── resume_session_stream ─────────────────────────────────────────── + + +def test_resume_session_stream_no_active_session( + mocker: pytest_mock.MockerFixture, +) -> None: + """GET /sessions/{id}/stream returns 204 when no active session.""" + mock_registry = MagicMock() + mock_registry.get_active_session = AsyncMock(return_value=(None, None)) + mocker.patch("backend.api.features.chat.routes.stream_registry", mock_registry) + + response = client.get("/sessions/sess-1/stream") + + assert response.status_code == 204 + + +def test_resume_session_stream_no_subscriber_queue( + mocker: pytest_mock.MockerFixture, +) -> None: + """GET /sessions/{id}/stream returns 204 when subscribe_to_session returns None.""" + from backend.copilot.stream_registry import ActiveSession + + active_session = ActiveSession( + session_id="sess-1", + user_id=TEST_USER_ID, + tool_call_id="chat_stream", + tool_name="chat", + turn_id="turn-1", + status="running", + ) + mock_registry = MagicMock() + mock_registry.get_active_session = AsyncMock(return_value=(active_session, "1-0")) + mock_registry.subscribe_to_session = AsyncMock(return_value=None) + mocker.patch("backend.api.features.chat.routes.stream_registry", mock_registry) + + response = client.get("/sessions/sess-1/stream") + + assert response.status_code == 204 + + # ─── DELETE /sessions/{id}/stream — disconnect listeners ────────────── @@ -1063,3 +1685,119 @@ def test_get_session_returns_backward_paginated( assert data["oldest_sequence"] == 0 assert "forward_paginated" not in data assert "newest_sequence" not in data + + +# ─── POST /sessions with builder_graph_id (get-or-create) ────────────── + + +def test_create_session_with_builder_graph_id_uses_get_or_create( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + """``POST /sessions`` with ``builder_graph_id`` routes through + ``get_or_create_builder_session`` and returns a session bound to the graph.""" + from backend.copilot.model import ChatSession + + async def _fake_get_or_create(user_id: str, graph_id: str) -> ChatSession: + return ChatSession.new( + user_id, + dry_run=False, + builder_graph_id=graph_id, + ) + + mocker.patch( + "backend.api.features.chat.routes.get_or_create_builder_session", + new_callable=AsyncMock, + side_effect=_fake_get_or_create, + ) + + response = client.post("/sessions", json={"builder_graph_id": "graph-1"}) + + assert response.status_code == 200 + body = response.json() + assert body["metadata"]["builder_graph_id"] == "graph-1" + assert body["metadata"]["dry_run"] is False + + +def test_create_session_with_builder_graph_id_returns_404_when_not_owned( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + """``get_or_create_builder_session`` raises ``NotFoundError`` when the + user doesn't own the graph; the route must map that to HTTP 404.""" + + async def _fake_get_or_create(user_id: str, graph_id: str): + raise NotFoundError(f"Graph {graph_id} not found") + + mocker.patch( + "backend.api.features.chat.routes.get_or_create_builder_session", + new_callable=AsyncMock, + side_effect=_fake_get_or_create, + ) + + response = client.post("/sessions", json={"builder_graph_id": "graph-unauthorized"}) + + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + + +def test_create_session_without_builder_graph_id_creates_fresh( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + """With no ``builder_graph_id`` the endpoint falls through to the + default ``create_chat_session`` path — no get-or-create lookup.""" + from backend.copilot.model import ChatSession + + gorc = mocker.patch( + "backend.api.features.chat.routes.get_or_create_builder_session", + new_callable=AsyncMock, + ) + + async def _fake_create(user_id: str, *, dry_run: bool) -> ChatSession: + return ChatSession.new(user_id, dry_run=dry_run) + + mocker.patch( + "backend.api.features.chat.routes.create_chat_session", + new_callable=AsyncMock, + side_effect=_fake_create, + ) + + response = client.post("/sessions", json={"dry_run": True}) + + assert response.status_code == 200 + assert response.json()["metadata"]["dry_run"] is True + gorc.assert_not_called() + + +def test_create_session_rejects_unknown_fields( + test_user_id: str, +) -> None: + """Extra request fields are rejected (422) to prevent silent mis-use.""" + response = client.post("/sessions", json={"unexpected": "x"}) + assert response.status_code == 422 + + +def test_resolve_session_permissions_blocks_out_of_scope_tools() -> None: + """Builder-bound sessions return a blacklist of the three tools that + conflict with the panel's graph-bound scope. Regular sessions return + ``None`` so default (unrestricted) behaviour is preserved.""" + from backend.copilot.builder_context import BUILDER_BLOCKED_TOOLS + from backend.copilot.model import ChatSession + + unbound = ChatSession.new("u1", dry_run=False) + assert chat_routes.resolve_session_permissions(unbound) is None + + bound = ChatSession.new("u1", dry_run=False, builder_graph_id="g1") + perms = chat_routes.resolve_session_permissions(bound) + assert perms is not None + assert perms.tools_exclude is True # blacklist, not whitelist + assert sorted(perms.tools) == sorted(BUILDER_BLOCKED_TOOLS) + # Read-side lookups stay available — only write-scope / guide-dup are blocked. + assert "find_block" not in perms.tools + assert "find_agent" not in perms.tools + assert "search_docs" not in perms.tools + # The write tools (edit_agent / run_agent) are NOT blacklisted — they + # enforce scope per-tool via the builder_graph_id guard. + assert "edit_agent" not in perms.tools + assert "run_agent" not in perms.tools diff --git a/autogpt_platform/backend/backend/api/features/library/db.py b/autogpt_platform/backend/backend/api/features/library/db.py index 1e01ea638f..0743b461c6 100644 --- a/autogpt_platform/backend/backend/api/features/library/db.py +++ b/autogpt_platform/backend/backend/api/features/library/db.py @@ -743,6 +743,7 @@ async def update_library_agent_version_and_settings( graph=agent_graph, hitl_safe_mode=library.settings.human_in_the_loop_safe_mode, sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode, + builder_chat_session_id=library.settings.builder_chat_session_id, ) if updated_settings != library.settings: library = await update_library_agent( diff --git a/autogpt_platform/backend/backend/blocks/_base.py b/autogpt_platform/backend/backend/blocks/_base.py index 2a26421c91..1cc29bd6d4 100644 --- a/autogpt_platform/backend/backend/blocks/_base.py +++ b/autogpt_platform/backend/backend/blocks/_base.py @@ -168,9 +168,31 @@ class BlockSchema(BaseModel): return cls.cached_jsonschema @classmethod - def validate_data(cls, data: BlockInput) -> str | None: + def validate_data( + cls, + data: BlockInput, + exclude_fields: set[str] | None = None, + ) -> str | None: + schema = cls.jsonschema() + if exclude_fields: + # Drop the excluded fields from both the properties and the + # ``required`` list so jsonschema doesn't flag them as missing. + # Used by the dry-run path to skip credentials validation while + # still validating the remaining block inputs. + schema = { + **schema, + "properties": { + k: v + for k, v in schema.get("properties", {}).items() + if k not in exclude_fields + }, + "required": [ + r for r in schema.get("required", []) if r not in exclude_fields + ], + } + data = {k: v for k, v in data.items() if k not in exclude_fields} return json.validate_with_jsonschema( - schema=cls.jsonschema(), + schema=schema, data={k: v for k, v in data.items() if v is not None}, ) @@ -717,11 +739,16 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): # (e.g. AgentExecutorBlock) get proper input validation. is_dry_run = getattr(kwargs.get("execution_context"), "dry_run", False) if is_dry_run: + # Credential fields may be absent (LLM-built agents often skip + # wiring them) or nullified earlier in the pipeline. Validate + # the non-credential inputs against a schema with those fields + # excluded — stripping only the data while keeping them in the + # ``required`` list would falsely report ``'credentials' is a + # required property``. cred_field_names = set(self.input_schema.get_credentials_fields().keys()) - non_cred_data = { - k: v for k, v in input_data.items() if k not in cred_field_names - } - if error := self.input_schema.validate_data(non_cred_data): + if error := self.input_schema.validate_data( + input_data, exclude_fields=cred_field_names + ): raise BlockInputError( message=f"Unable to execute block with invalid input data: {error}", block_name=self.name, diff --git a/autogpt_platform/backend/backend/blocks/perplexity.py b/autogpt_platform/backend/backend/blocks/perplexity.py index a8b137ce2b..abdbadef91 100644 --- a/autogpt_platform/backend/backend/blocks/perplexity.py +++ b/autogpt_platform/backend/backend/blocks/perplexity.py @@ -98,14 +98,23 @@ class PerplexityBlock(Block): return _sanitize_perplexity_model(v) @classmethod - def validate_data(cls, data: BlockInput) -> str | None: + def validate_data( + cls, + data: BlockInput, + exclude_fields: set[str] | None = None, + ) -> str | None: """Sanitize the model field before JSON schema validation so that invalid values are replaced with the default instead of raising a - BlockInputError.""" + BlockInputError. + + Signature matches ``BlockSchema.validate_data`` (including the + optional ``exclude_fields`` kwarg added for dry-run credential + bypass) so Pyright doesn't flag this as an incompatible override. + """ model_value = data.get("model") if model_value is not None: data["model"] = _sanitize_perplexity_model(model_value).value - return super().validate_data(data) + return super().validate_data(data, exclude_fields=exclude_fields) system_prompt: str = SchemaField( title="System Prompt", diff --git a/autogpt_platform/backend/backend/copilot/baseline/service.py b/autogpt_platform/backend/backend/copilot/baseline/service.py index f87ec05390..474a6834b1 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service.py @@ -31,6 +31,10 @@ from backend.copilot.baseline.reasoning import ( BaselineReasoningEmitter, reasoning_extra_body, ) +from backend.copilot.builder_context import ( + build_builder_context_turn_prefix, + build_builder_system_prompt_suffix, +) from backend.copilot.config import CopilotLlmModel, CopilotMode from backend.copilot.context import get_workspace_manager, set_execution_context from backend.copilot.graphiti.config import is_enabled_for_user @@ -1388,7 +1392,18 @@ async def stream_chat_completion_baseline( graphiti_enabled = await is_enabled_for_user(user_id) graphiti_supplement = get_graphiti_supplement() if graphiti_enabled else "" - system_prompt = base_system_prompt + SHARED_TOOL_NOTES + graphiti_supplement + # Append the builder-session block (graph id+name + full building guide) + # AFTER the shared supplements so the system prompt is byte-identical + # across turns of the same builder session — Claude's prompt cache keeps + # the ~20KB guide warm for the whole session. Empty string for + # non-builder sessions keeps the cross-user cache hot. + builder_session_suffix = await build_builder_system_prompt_suffix(session) + system_prompt = ( + base_system_prompt + + SHARED_TOOL_NOTES + + graphiti_supplement + + builder_session_suffix + ) # Warm context: pre-load relevant facts from Graphiti on first turn. # Use the pre-drain count so pending messages drained at turn start @@ -1472,6 +1487,26 @@ async def stream_chat_completion_baseline( # Do NOT append warm_ctx to user_message_for_transcript — it would # persist stale temporal context into the transcript for future turns. + # Inject the per-turn ```` prefix when the session is + # bound to a graph via ``metadata.builder_graph_id``. Runs on every + # user turn (not just the first) so the LLM always sees the live graph + # snapshot — if the user edits the graph between turns, the next turn + # carries the updated nodes/links. Only version + nodes + links here; + # the static guide + graph id live in the system prompt via + # ``build_builder_system_prompt_suffix`` (session-stable, prompt-cached). + # Prepended AFTER any // blocks + # — same trust tier as those server-injected prefixes. Not persisted to + # the transcript: the snapshot is stale-by-definition after the turn ends. + if is_user_message and session.metadata.builder_graph_id: + builder_block = await build_builder_context_turn_prefix(session, user_id) + if builder_block: + for msg in reversed(openai_messages): + if msg["role"] == "user": + existing = msg.get("content", "") + if isinstance(existing, str): + msg["content"] = builder_block + existing + break + # Append user message to transcript. # Always append when the message is present and is from the user, # even on duplicate-suppressed retries (is_new_message=False). diff --git a/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py b/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py index 4092206786..03a9ef99c9 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py @@ -1233,6 +1233,81 @@ class TestMidLoopPendingFlushOrdering: assert len(assistant_msgs) == 2 +class TestBuilderContextSplit: + """Cross-helper composition: the guide must land in the system prompt via + ``build_builder_system_prompt_suffix`` and NOT in the per-turn user prefix + via ``build_builder_context_turn_prefix``. + + The baseline service composes these two blocks on each turn, so a drift + here (guide leaking into both, or missing from both) would kill Claude's + prompt-cache hit rate for builder sessions. + """ + + @pytest.mark.asyncio + async def test_guide_lives_in_system_prompt_not_user_message(self): + from backend.copilot.builder_context import ( + BUILDER_CONTEXT_TAG, + BUILDER_SESSION_TAG, + build_builder_context_turn_prefix, + build_builder_system_prompt_suffix, + ) + from backend.copilot.model import ChatSession + + session = MagicMock(spec=ChatSession) + session.session_id = "s" + session.metadata = MagicMock() + session.metadata.builder_graph_id = "graph-1" + + agent_json = { + "id": "graph-1", + "name": "Demo", + "version": 7, + "nodes": [ + { + "id": "n1", + "block_id": "block-A", + "input_default": {"name": "Input"}, + "metadata": {}, + } + ], + "links": [], + } + guide_body = "# UNIQUE_GUIDE_MARKER body" + with ( + patch( + "backend.copilot.builder_context.get_agent_as_json", + new=AsyncMock(return_value=agent_json), + ), + patch( + "backend.copilot.builder_context._load_guide", + return_value=guide_body, + ), + ): + suffix = await build_builder_system_prompt_suffix(session) + prefix = await build_builder_context_turn_prefix(session, "user-1") + + # System prompt suffix carries and the guide. + assert f"<{BUILDER_SESSION_TAG}>" in suffix + assert guide_body in suffix + # Dynamic bits must NOT be in the suffix — otherwise renames and + # cross-graph sessions invalidate Claude's prompt cache. + assert "graph-1" not in suffix + assert "Demo" not in suffix + + # Per-turn prefix carries with the full live + # snapshot (id, name, version, nodes) but NEVER the guide. + assert f"<{BUILDER_CONTEXT_TAG}>" in prefix + assert 'id="graph-1"' in prefix + assert 'name="Demo"' in prefix + assert 'version="7"' in prefix + assert guide_body not in prefix + assert "" not in prefix + + # Guide appears in the combined on-the-wire payload exactly ONCE. + combined = suffix + "\n\n" + prefix + assert combined.count(guide_body) == 1 + + class TestApplyPromptCacheMarkers: """Tests for _apply_prompt_cache_markers — Anthropic ephemeral cache_control markers on baseline OpenRouter requests.""" diff --git a/autogpt_platform/backend/backend/copilot/builder_context.py b/autogpt_platform/backend/backend/copilot/builder_context.py new file mode 100644 index 0000000000..9f36350d1c --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/builder_context.py @@ -0,0 +1,217 @@ +"""Builder-session context helpers — split cacheable system prompt from +the volatile per-turn snapshot so Claude's prompt cache stays warm.""" + +from __future__ import annotations + +import logging +from typing import Any + +from backend.copilot.model import ChatSession +from backend.copilot.permissions import CopilotPermissions +from backend.copilot.tools.agent_generator import get_agent_as_json +from backend.copilot.tools.get_agent_building_guide import _load_guide + +logger = logging.getLogger(__name__) + + +BUILDER_CONTEXT_TAG = "builder_context" +BUILDER_SESSION_TAG = "builder_session" + + +# Tools hidden from builder-bound sessions: ``create_agent`` / +# ``customize_agent`` would mint a new graph (panel is bound to one), +# and ``get_agent_building_guide`` duplicates bytes already in the +# system-prompt suffix. Everything else (find_block, find_agent, …) +# stays available so the LLM can look up ids instead of hallucinating. +BUILDER_BLOCKED_TOOLS: tuple[str, ...] = ( + "create_agent", + "customize_agent", + "get_agent_building_guide", +) + + +def resolve_session_permissions( + session: ChatSession | None, +) -> CopilotPermissions | None: + """Blacklist :data:`BUILDER_BLOCKED_TOOLS` for builder-bound sessions, + return ``None`` (unrestricted) otherwise.""" + if session is None or not session.metadata.builder_graph_id: + return None + return CopilotPermissions( + tools=list(BUILDER_BLOCKED_TOOLS), + tools_exclude=True, + ) + + +# Caps — mirror the frontend ``serializeGraphForChat`` defaults so the +# server-side block stays within a practical token budget for large graphs. +_MAX_NODES = 100 +_MAX_LINKS = 200 + +_FETCH_FAILED_PREFIX = ( + f"<{BUILDER_CONTEXT_TAG}>\n" + f"fetch_failed\n" + f"\n\n" +) + +# Embedded in the cacheable suffix so the LLM picks the right run_agent +# dispatch mode without forcing the user to watch a long-blocking call. +_BUILDER_RUN_AGENT_GUIDANCE = ( + "You are operating inside the builder panel, not the standalone " + "copilot page. The builder page already subscribes to agent " + "executions the moment you return an execution_id, so for REAL " + "(non-dry) runs prefer `run_agent(dry_run=False, wait_for_result=0)` " + "— the user will see the run stream in the builder's execution panel " + "in-place and your turn ends immediately with the id. For DRY-RUNS " + "keep `dry_run=True, wait_for_result=120`: blocking is required so " + "you can inspect `execution.node_executions` and report the verdict " + "in the same turn." +) + + +def _sanitize_for_xml(value: Any) -> str: + """Escape XML special chars — mirrors ``sanitizeForXml`` in + ``BuilderChatPanel/helpers.ts``.""" + s = "" if value is None else str(value) + return ( + s.replace("&", "&") + .replace("<", "<") + .replace(">", ">") + .replace('"', """) + .replace("'", "'") + ) + + +def _node_display_name(node: dict[str, Any]) -> str: + """Prefer the user-set label (``input_default.name`` / ``metadata.title``); + fall back to the block id.""" + defaults = node.get("input_default") or {} + metadata = node.get("metadata") or {} + for key in ("name", "title", "label"): + value = defaults.get(key) or metadata.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + block_id = node.get("block_id") or "" + return block_id or "unknown" + + +def _format_nodes(nodes: list[dict[str, Any]]) -> str: + if not nodes: + return "\n" + visible = nodes[:_MAX_NODES] + lines = [] + for node in visible: + node_id = _sanitize_for_xml(node.get("id") or "") + name = _sanitize_for_xml(_node_display_name(node)) + block_id = _sanitize_for_xml(node.get("block_id") or "") + lines.append(f"- {node_id}: {name} ({block_id})") + extra = len(nodes) - len(visible) + if extra > 0: + lines.append(f"({extra} more not shown)") + body = "\n".join(lines) + return f"\n{body}\n" + + +def _format_links( + links: list[dict[str, Any]], + nodes: list[dict[str, Any]], +) -> str: + if not links: + return "\n" + name_by_id = {n.get("id"): _node_display_name(n) for n in nodes} + visible = links[:_MAX_LINKS] + lines = [] + for link in visible: + src_id = link.get("source_id") or "" + dst_id = link.get("sink_id") or "" + src_name = name_by_id.get(src_id, src_id) + dst_name = name_by_id.get(dst_id, dst_id) + src_out = link.get("source_name") or "" + dst_in = link.get("sink_name") or "" + lines.append( + f"- {_sanitize_for_xml(src_name)}.{_sanitize_for_xml(src_out)} " + f"-> {_sanitize_for_xml(dst_name)}.{_sanitize_for_xml(dst_in)}" + ) + extra = len(links) - len(visible) + if extra > 0: + lines.append(f"({extra} more not shown)") + body = "\n".join(lines) + return f"\n{body}\n" + + +async def build_builder_system_prompt_suffix(session: ChatSession) -> str: + """Return the cacheable system-prompt suffix for a builder session. + + Holds only static content (dispatch guidance + building guide) so the + bytes are identical across turns AND across sessions for different + graphs — the live id/name/version ride on the per-turn prefix. + """ + if not session.metadata.builder_graph_id: + return "" + + try: + guide = _load_guide() + except Exception: + logger.exception("[builder_context] Failed to load agent-building guide") + return "" + + # The guide is trusted server-side content (read from disk). We do NOT + # escape it — the LLM needs the raw markdown to make sense of block ids, + # code fences, and example JSON. + return ( + f"\n\n<{BUILDER_SESSION_TAG}>\n" + f"\n" + f"{_BUILDER_RUN_AGENT_GUIDANCE}\n" + f"\n" + f"\n{guide}\n\n" + f"" + ) + + +async def build_builder_context_turn_prefix( + session: ChatSession, + user_id: str | None, +) -> str: + """Return the per-turn ```` prefix with the live + graph snapshot (id/name/version/nodes/links). ``""`` for non-builder + sessions; fetch-failure marker if the graph cannot be read.""" + graph_id = session.metadata.builder_graph_id + if not graph_id: + return "" + + try: + agent_json = await get_agent_as_json(graph_id, user_id) + except Exception: + logger.exception( + "[builder_context] Failed to fetch graph %s for session %s", + graph_id, + session.session_id, + ) + return _FETCH_FAILED_PREFIX + + if not agent_json: + logger.warning( + "[builder_context] Graph %s not found for session %s", + graph_id, + session.session_id, + ) + return _FETCH_FAILED_PREFIX + + version = _sanitize_for_xml(agent_json.get("version") or "") + raw_name = agent_json.get("name") + graph_name = ( + raw_name.strip() if isinstance(raw_name, str) and raw_name.strip() else None + ) + nodes = agent_json.get("nodes") or [] + links = agent_json.get("links") or [] + name_attr = f' name="{_sanitize_for_xml(graph_name)}"' if graph_name else "" + graph_tag = ( + f'' + ) + + inner = f"{graph_tag}\n{_format_nodes(nodes)}\n{_format_links(links, nodes)}" + return f"<{BUILDER_CONTEXT_TAG}>\n{inner}\n\n\n" diff --git a/autogpt_platform/backend/backend/copilot/builder_context_test.py b/autogpt_platform/backend/backend/copilot/builder_context_test.py new file mode 100644 index 0000000000..efeb6f7dad --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/builder_context_test.py @@ -0,0 +1,329 @@ +"""Tests for the split builder-context helpers. + +Covers both halves of the public API: + +- :func:`build_builder_system_prompt_suffix` — session-stable block + appended to the system prompt (contains the guide + graph id/name). +- :func:`build_builder_context_turn_prefix` — per-turn user-message + prefix (contains the live version + node/link snapshot). +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, patch + +import pytest + +from backend.copilot.builder_context import ( + BUILDER_CONTEXT_TAG, + BUILDER_SESSION_TAG, + build_builder_context_turn_prefix, + build_builder_system_prompt_suffix, +) +from backend.copilot.model import ChatSession + + +def _session( + builder_graph_id: str | None, + *, + user_id: str = "test-user", +) -> ChatSession: + """Minimal ``ChatSession`` with *builder_graph_id* on metadata.""" + return ChatSession.new( + user_id, + dry_run=False, + builder_graph_id=builder_graph_id, + ) + + +def _agent_json( + nodes: list[dict] | None = None, + links: list[dict] | None = None, + **overrides, +) -> dict: + base: dict = { + "id": "graph-1", + "name": "My Agent", + "description": "A test agent", + "version": 3, + "is_active": True, + "nodes": nodes if nodes is not None else [], + "links": links if links is not None else [], + } + base.update(overrides) + return base + + +# --------------------------------------------------------------------------- +# build_builder_system_prompt_suffix +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_system_prompt_suffix_empty_for_non_builder(): + session = _session(None) + result = await build_builder_system_prompt_suffix(session) + assert result == "" + + +@pytest.mark.asyncio +async def test_system_prompt_suffix_contains_only_static_content(): + session = _session("graph-1") + with patch( + "backend.copilot.builder_context._load_guide", + return_value="# Guide body", + ): + suffix = await build_builder_system_prompt_suffix(session) + + assert suffix.startswith("\n\n") + assert f"<{BUILDER_SESSION_TAG}>" in suffix + assert f"" in suffix + assert "" in suffix + assert "# Guide body" in suffix + # Dispatch-mode guidance must appear so the LLM knows to prefer + # wait_for_result=0 for real runs (builder UI subscribes live) and + # wait_for_result=120 for dry-runs (so it can inspect the node trace). + assert "" in suffix + assert "wait_for_result=0" in suffix + assert "wait_for_result=120" in suffix + # Regression: dynamic graph id/name must NOT leak into the cacheable + # suffix — they live in the per-turn prefix so renames and cross-graph + # sessions don't invalidate Claude's prompt cache. + assert "graph-1" not in suffix + assert "id=" not in suffix + assert "name=" not in suffix + + +@pytest.mark.asyncio +async def test_system_prompt_suffix_identical_across_graphs(): + """The suffix must be byte-identical regardless of which graph the + session is bound to — that's what keeps the cacheable prefix warm + across sessions.""" + s1 = _session("graph-1") + s2 = _session("graph-2", user_id="different-owner") + with patch( + "backend.copilot.builder_context._load_guide", + return_value="# Guide body", + ): + suffix_1 = await build_builder_system_prompt_suffix(s1) + suffix_2 = await build_builder_system_prompt_suffix(s2) + + assert suffix_1 == suffix_2 + + +@pytest.mark.asyncio +async def test_system_prompt_suffix_empty_when_guide_load_fails(): + """Guide load failure means we have nothing useful to add — emit an + empty suffix rather than a half-built block.""" + session = _session("graph-1") + with patch( + "backend.copilot.builder_context._load_guide", + side_effect=OSError("missing"), + ): + suffix = await build_builder_system_prompt_suffix(session) + + assert suffix == "" + + +# --------------------------------------------------------------------------- +# build_builder_context_turn_prefix +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_turn_prefix_empty_for_non_builder(): + session = _session(None) + result = await build_builder_context_turn_prefix(session, "user-1") + assert result == "" + + +@pytest.mark.asyncio +async def test_turn_prefix_contains_version_nodes_and_links(): + session = _session("graph-1") + nodes = [ + { + "id": "n1", + "block_id": "block-A", + "input_default": {"name": "Input"}, + "metadata": {}, + }, + { + "id": "n2", + "block_id": "block-B", + "input_default": {}, + "metadata": {}, + }, + ] + links = [ + { + "source_id": "n1", + "sink_id": "n2", + "source_name": "out", + "sink_name": "in", + } + ] + agent = _agent_json(nodes=nodes, links=links) + with patch( + "backend.copilot.builder_context.get_agent_as_json", + new=AsyncMock(return_value=agent), + ): + block = await build_builder_context_turn_prefix(session, "user-1") + + assert block.startswith(f"<{BUILDER_CONTEXT_TAG}>\n") + assert block.endswith(f"\n\n") + assert 'id="graph-1"' in block + assert 'name="My Agent"' in block + assert 'version="3"' in block + assert 'node_count="2"' in block + assert 'edge_count="1"' in block + assert "n1: Input (block-A)" in block + assert "n2: block-B (block-B)" in block + assert "Input.out -> block-B.in" in block + + +@pytest.mark.asyncio +async def test_turn_prefix_does_not_include_guide(): + """The guide lives in the cacheable system prompt, not in the per-turn + prefix.""" + session = _session("graph-1") + with ( + patch( + "backend.copilot.builder_context.get_agent_as_json", + new=AsyncMock(return_value=_agent_json()), + ), + # Sentinel guide text — if it leaks into the turn prefix the + # assertion below catches it. + patch( + "backend.copilot.builder_context._load_guide", + return_value="SENTINEL_GUIDE_BODY", + ), + ): + block = await build_builder_context_turn_prefix(session, "user-1") + + assert "SENTINEL_GUIDE_BODY" not in block + assert "" not in block + + +@pytest.mark.asyncio +async def test_turn_prefix_escapes_graph_name(): + session = _session("graph-1") + with patch( + "backend.copilot.builder_context.get_agent_as_json", + new=AsyncMock(return_value=_agent_json(name='", - description: "", - hardcodedValues: {}, - inputSchema: {}, - outputSchema: {}, - uiType: 1, - block_id: "b1", - costs: [], - categories: [], - }, - type: "custom" as const, - position: { x: 0, y: 0 }, - }, - ] as unknown as CustomNode[]; - - const result = serializeGraphForChat(nodes, []); - expect(result).not.toContain("`; - const wrapped = wrapWithHeadInjection(content, tailwindScript); + const wrapped = wrapWithHeadInjection( + content, + tailwindScript + FRAGMENT_LINK_INTERCEPTOR_SCRIPT, + ); return (