From d82ecac363c6dc2cfa6bc1b92a4994750032e9a2 Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Wed, 15 Apr 2026 14:50:34 +0700 Subject: [PATCH 1/5] fix(backend/copilot): null-safe token accumulation for OpenRouter null cache fields (#12789) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Why OpenRouter occasionally returns `null` (not `0`) for `cache_read_input_tokens` and `cache_creation_input_tokens` on the initial streaming event, before real token counts are available. Python's `dict.get(key, 0)` only falls back to `0` when the key is **missing** — when the key exists with a `null` value, `.get(key, 0)` returns `None`. This causes `TypeError: unsupported operand type(s) for +=: 'int' and 'NoneType'` in the usage accumulator on the first streaming chunk from OpenRouter models. ## What - Replace `.get(key, 0)` with `.get(key) or 0` for all four token fields in `_run_stream_attempt` - Add `TestTokenUsageNullSafety` unit tests in `service_helpers_test.py` ## How Minimal targeted fix — only the four `+=` accumulation lines changed. No behaviour change for Anthropic-native models (they never emit null values). ## Checklist - [x] Tests cover null event, real event, absent keys, and multi-turn accumulation - [x] No behaviour change for Anthropic-native models - [x] No API changes --- .../backend/backend/copilot/sdk/service.py | 19 +++-- .../copilot/sdk/service_helpers_test.py | 85 +++++++++++++++++++ 2 files changed, 97 insertions(+), 7 deletions(-) diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py index f4aa019b08..c7d166adba 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service.py @@ -1865,15 +1865,20 @@ async def _run_stream_attempt( # cache_read_input_tokens = served from cache # cache_creation_input_tokens = written to cache if sdk_msg.usage: - state.usage.prompt_tokens += sdk_msg.usage.get("input_tokens", 0) - state.usage.cache_read_tokens += sdk_msg.usage.get( - "cache_read_input_tokens", 0 + # Use `or 0` instead of a default in .get() because + # OpenRouter may include the key with a null value (e.g. + # {"cache_read_input_tokens": null}) for models that don't + # yet report cache tokens, making .get("key", 0) return + # None rather than the fallback 0. + state.usage.prompt_tokens += sdk_msg.usage.get("input_tokens") or 0 + state.usage.cache_read_tokens += ( + sdk_msg.usage.get("cache_read_input_tokens") or 0 ) - state.usage.cache_creation_tokens += sdk_msg.usage.get( - "cache_creation_input_tokens", 0 + state.usage.cache_creation_tokens += ( + sdk_msg.usage.get("cache_creation_input_tokens") or 0 ) - state.usage.completion_tokens += sdk_msg.usage.get( - "output_tokens", 0 + state.usage.completion_tokens += ( + sdk_msg.usage.get("output_tokens") or 0 ) logger.info( "%s Token usage: uncached=%d, cache_read=%d, " diff --git a/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py b/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py index be2c46bdbb..5f1487c43b 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py @@ -21,6 +21,7 @@ from .service import ( _is_tool_only_message, _iter_sdk_messages, _reduce_context, + _TokenUsage, ) # --------------------------------------------------------------------------- @@ -350,3 +351,87 @@ class TestIsParallelContinuation: msg = MagicMock(spec=AssistantMessage) msg.content = [self._make_tool_block()] assert _is_tool_only_message(msg) is True + + +# --------------------------------------------------------------------------- +# _TokenUsage — null-safe accumulation (OpenRouter initial-stream-event bug) +# --------------------------------------------------------------------------- + + +class TestTokenUsageNullSafety: + """Verify that ResultMessage.usage dicts with null-valued cache fields + (as emitted by OpenRouter for the initial streaming event before real + token counts are available) do not crash the accumulator. + + Before the fix, dict.get("cache_read_input_tokens", 0) returned None + when the key existed with a null value, causing 'int += None' TypeError. + """ + + def _apply_usage(self, usage: dict, acc: _TokenUsage) -> None: + """Mirror the production accumulation in sdk/service.py.""" + acc.prompt_tokens += usage.get("input_tokens") or 0 + acc.cache_read_tokens += usage.get("cache_read_input_tokens") or 0 + acc.cache_creation_tokens += usage.get("cache_creation_input_tokens") or 0 + acc.completion_tokens += usage.get("output_tokens") or 0 + + def test_null_cache_tokens_do_not_crash(self): + """OpenRouter initial event: cache keys present with null value.""" + usage = { + "input_tokens": 0, + "output_tokens": 0, + "cache_read_input_tokens": None, + "cache_creation_input_tokens": None, + } + acc = _TokenUsage() + self._apply_usage(usage, acc) # must not raise TypeError + assert acc.prompt_tokens == 0 + assert acc.cache_read_tokens == 0 + assert acc.cache_creation_tokens == 0 + assert acc.completion_tokens == 0 + + def test_real_cache_tokens_are_accumulated(self): + """OpenRouter final event: real cache token counts are captured.""" + usage = { + "input_tokens": 10, + "output_tokens": 349, + "cache_read_input_tokens": 16600, + "cache_creation_input_tokens": 512, + } + acc = _TokenUsage() + self._apply_usage(usage, acc) + assert acc.prompt_tokens == 10 + assert acc.cache_read_tokens == 16600 + assert acc.cache_creation_tokens == 512 + assert acc.completion_tokens == 349 + + def test_absent_cache_keys_default_to_zero(self): + """Minimal usage dict without cache keys defaults correctly.""" + usage = {"input_tokens": 5, "output_tokens": 20} + acc = _TokenUsage() + self._apply_usage(usage, acc) + assert acc.prompt_tokens == 5 + assert acc.cache_read_tokens == 0 + assert acc.cache_creation_tokens == 0 + assert acc.completion_tokens == 20 + + def test_multi_turn_accumulation(self): + """Null event followed by real event: only real tokens counted.""" + null_event = { + "input_tokens": 0, + "output_tokens": 0, + "cache_read_input_tokens": None, + "cache_creation_input_tokens": None, + } + real_event = { + "input_tokens": 10, + "output_tokens": 349, + "cache_read_input_tokens": 16600, + "cache_creation_input_tokens": 512, + } + acc = _TokenUsage() + self._apply_usage(null_event, acc) + self._apply_usage(real_event, acc) + assert acc.prompt_tokens == 10 + assert acc.cache_read_tokens == 16600 + assert acc.cache_creation_tokens == 512 + assert acc.completion_tokens == 349 From da18f372f7763511c749045c657b743502a7dcb3 Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Wed, 15 Apr 2026 14:57:17 +0700 Subject: [PATCH 2/5] feat(backend/copilot): add for_agent_generation flag to find_block (#12787) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Why When the agent generator LLM builds a graph, it may need to look up schema details for graph-only blocks like `AgentInputBlock`, `AgentOutputBlock`, or `OrchestratorBlock`. These blocks are correctly hidden from regular CoPilot `find_block` results (they can't run standalone), but that same filter was also preventing the LLM from discovering them when composing an agent graph. ## What Added a `for_agent_generation: bool = False` parameter to `FindBlockTool`. ## How - `for_agent_generation=false` (default): existing behaviour unchanged — graph-only blocks are filtered from both UUID lookups and text search results. - `for_agent_generation=true`: bypasses `COPILOT_EXCLUDED_BLOCK_TYPES` / `COPILOT_EXCLUDED_BLOCK_IDS` so the LLM can find and inspect schemas for INPUT, OUTPUT, ORCHESTRATOR, WEBHOOK, etc. blocks when building agent JSON. - MCP_TOOL blocks are still excluded even with `for_agent_generation=true` (they go through `run_mcp_tool`, not `find_block`). ## Checklist - [x] No new dependencies - [x] Backward compatible (default `false` preserves existing behaviour) - [x] No frontend changes --- .../copilot/sdk/agent_generation_guide.md | 14 +- .../backend/copilot/tools/find_block.py | 59 ++-- .../backend/copilot/tools/find_block_test.py | 267 +++++++++++++++++- 3 files changed, 307 insertions(+), 33 deletions(-) diff --git a/autogpt_platform/backend/backend/copilot/sdk/agent_generation_guide.md b/autogpt_platform/backend/backend/copilot/sdk/agent_generation_guide.md index 35b4a348b9..145354b704 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/agent_generation_guide.md +++ b/autogpt_platform/backend/backend/copilot/sdk/agent_generation_guide.md @@ -34,9 +34,13 @@ Steps: always inspect the current graph first so you know exactly what to change. Avoid using `include_graph=true` with broad keyword searches, as fetching multiple graphs at once is expensive and consumes LLM context budget. -2. **Discover blocks**: Call `find_block(query, include_schemas=true)` to +2. **Discover blocks**: Call `find_block(query, include_schemas=true, for_agent_generation=true)` to search for relevant blocks. This returns block IDs, names, descriptions, - and full input/output schemas. + and full input/output schemas. The `for_agent_generation=true` flag is + required to surface graph-only blocks such as AgentInputBlock, + AgentDropdownInputBlock, AgentOutputBlock, OrchestratorBlock, + and WebhookBlock and MCPToolBlock. (When running MCP tools interactively + in CoPilot outside agent generation, use `run_mcp_tool` instead.) 3. **Find library agents**: Call `find_library_agent` to discover reusable agents that can be composed as sub-agents via `AgentExecutorBlock`. 4. **Generate/modify JSON**: Build or modify the agent JSON using block schemas: @@ -177,6 +181,12 @@ To compose agents using other agents as sub-agents: ### Using MCP Tools (MCPToolBlock) +> **Agent graph vs CoPilot direct execution**: This section covers embedding MCP +> tools as persistent nodes in an agent graph. When running MCP tools directly in +> CoPilot (outside agent generation), use `run_mcp_tool` instead — it handles +> server discovery and authentication interactively. Use `MCPToolBlock` here only +> when the user wants the MCP call baked into a reusable agent graph. + To use an MCP (Model Context Protocol) tool as a node in the agent: 1. The user must specify which MCP server URL and tool name they want 2. Create an `MCPToolBlock` node (ID: `a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4`) diff --git a/autogpt_platform/backend/backend/copilot/tools/find_block.py b/autogpt_platform/backend/backend/copilot/tools/find_block.py index 0cbc3ba047..130e26562b 100644 --- a/autogpt_platform/backend/backend/copilot/tools/find_block.py +++ b/autogpt_platform/backend/backend/copilot/tools/find_block.py @@ -74,6 +74,15 @@ class FindBlockTool(BaseTool): "description": "Include full input/output schemas (for agent JSON generation).", "default": False, }, + "for_agent_generation": { + "type": "boolean", + "description": ( + "Set to true when searching for blocks to use inside an agent graph " + "(e.g. AgentInputBlock, AgentOutputBlock, OrchestratorBlock). " + "Bypasses the CoPilot-only filter so graph-only blocks are visible." + ), + "default": False, + }, }, "required": ["query"], } @@ -88,6 +97,7 @@ class FindBlockTool(BaseTool): session: ChatSession, query: str = "", include_schemas: bool = False, + for_agent_generation: bool = False, **kwargs, ) -> ToolResponseBase: """Search for blocks matching the query. @@ -97,6 +107,8 @@ class FindBlockTool(BaseTool): session: Chat session query: Search query include_schemas: Whether to include block schemas in results + for_agent_generation: When True, bypasses the CoPilot exclusion filter + so graph-only blocks (INPUT, OUTPUT, ORCHESTRATOR, etc.) are visible. Returns: BlockListResponse: List of matching blocks @@ -123,34 +135,36 @@ class FindBlockTool(BaseTool): suggestions=["Search for an alternative block by name"], session_id=session_id, ) - if ( + is_excluded = ( block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES or block.id in COPILOT_EXCLUDED_BLOCK_IDS - ): - if block.block_type == BlockType.MCP_TOOL: + ) + if is_excluded: + # Graph-only blocks (INPUT, OUTPUT, MCP_TOOL, AGENT, etc.) are + # exposed when building an agent graph so the LLM can inspect + # their schemas and wire them as nodes. In CoPilot direct use + # they are not executable — guide the LLM to the right tool. + if not for_agent_generation: + if block.block_type == BlockType.MCP_TOOL: + message = ( + f"Block '{block.name}' (ID: {block.id}) cannot be " + "run directly in CoPilot. Use run_mcp_tool for " + "interactive MCP execution, or call find_block with " + "for_agent_generation=true to embed it in an agent graph." + ) + else: + message = ( + f"Block '{block.name}' (ID: {block.id}) is not available " + "in CoPilot. It can only be used within agent graphs." + ) return NoResultsResponse( - message=( - f"Block '{block.name}' (ID: {block.id}) is not " - "runnable through find_block/run_block. Use " - "run_mcp_tool instead." - ), + message=message, suggestions=[ - "Use run_mcp_tool to discover and run this MCP tool", "Search for an alternative block by name", + "Use this block in an agent graph instead", ], session_id=session_id, ) - return NoResultsResponse( - message=( - f"Block '{block.name}' (ID: {block.id}) is not available " - "in CoPilot. It can only be used within agent graphs." - ), - suggestions=[ - "Search for an alternative block by name", - "Use this block in an agent graph instead", - ], - session_id=session_id, - ) # Check block-level permissions — hide denied blocks entirely perms = get_current_permissions() @@ -221,8 +235,9 @@ class FindBlockTool(BaseTool): if not block or block.disabled: continue - # Skip blocks excluded from CoPilot (graph-only blocks) - if ( + # Graph-only blocks (INPUT, OUTPUT, MCP_TOOL, AGENT, etc.) are + # skipped in CoPilot direct use but surfaced for agent graph building. + if not for_agent_generation and ( block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES or block.id in COPILOT_EXCLUDED_BLOCK_IDS ): diff --git a/autogpt_platform/backend/backend/copilot/tools/find_block_test.py b/autogpt_platform/backend/backend/copilot/tools/find_block_test.py index 64a7fe3788..d99672daa2 100644 --- a/autogpt_platform/backend/backend/copilot/tools/find_block_test.py +++ b/autogpt_platform/backend/backend/copilot/tools/find_block_test.py @@ -12,7 +12,7 @@ from .find_block import ( COPILOT_EXCLUDED_BLOCK_TYPES, FindBlockTool, ) -from .models import BlockListResponse +from .models import BlockListResponse, NoResultsResponse _TEST_USER_ID = "test-user-find-block" @@ -166,6 +166,194 @@ class TestFindBlockFiltering: assert len(response.blocks) == 1 assert response.blocks[0].id == "normal-block-id" + @pytest.mark.asyncio(loop_scope="session") + async def test_for_agent_generation_exposes_excluded_blocks_in_search(self): + """With for_agent_generation=True, excluded block types appear in search results.""" + session = make_session(user_id=_TEST_USER_ID) + + search_results = [ + {"content_id": "input-block-id", "score": 0.9}, + {"content_id": "output-block-id", "score": 0.8}, + ] + input_block = make_mock_block("input-block-id", "Agent Input", BlockType.INPUT) + output_block = make_mock_block( + "output-block-id", "Agent Output", BlockType.OUTPUT + ) + + def mock_get_block(block_id): + return { + "input-block-id": input_block, + "output-block-id": output_block, + }.get(block_id) + + mock_search_db = MagicMock() + mock_search_db.unified_hybrid_search = AsyncMock( + return_value=(search_results, 2) + ) + + with patch( + "backend.copilot.tools.find_block.search", + return_value=mock_search_db, + ): + with patch( + "backend.copilot.tools.find_block.get_block", + side_effect=mock_get_block, + ): + tool = FindBlockTool() + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + query="agent input", + for_agent_generation=True, + ) + + assert isinstance(response, BlockListResponse) + assert len(response.blocks) == 2 + block_ids = {b.id for b in response.blocks} + assert "input-block-id" in block_ids + assert "output-block-id" in block_ids + + @pytest.mark.asyncio(loop_scope="session") + async def test_mcp_tool_exposed_with_for_agent_generation_in_search(self): + """MCP_TOOL blocks appear in search results when for_agent_generation=True.""" + session = make_session(user_id=_TEST_USER_ID) + + search_results = [ + {"content_id": "mcp-block-id", "score": 0.9}, + {"content_id": "standard-block-id", "score": 0.8}, + ] + mcp_block = make_mock_block("mcp-block-id", "MCP Tool", BlockType.MCP_TOOL) + standard_block = make_mock_block( + "standard-block-id", "Normal Block", BlockType.STANDARD + ) + + def mock_get_block(block_id): + return { + "mcp-block-id": mcp_block, + "standard-block-id": standard_block, + }.get(block_id) + + mock_search_db = MagicMock() + mock_search_db.unified_hybrid_search = AsyncMock( + return_value=(search_results, 2) + ) + + with patch( + "backend.copilot.tools.find_block.search", + return_value=mock_search_db, + ): + with patch( + "backend.copilot.tools.find_block.get_block", + side_effect=mock_get_block, + ): + tool = FindBlockTool() + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + query="mcp tool", + for_agent_generation=True, + ) + + assert isinstance(response, BlockListResponse) + assert len(response.blocks) == 2 + assert any(b.id == "mcp-block-id" for b in response.blocks) + assert any(b.id == "standard-block-id" for b in response.blocks) + + @pytest.mark.asyncio(loop_scope="session") + async def test_mcp_tool_excluded_without_for_agent_generation_in_search(self): + """MCP_TOOL blocks are excluded from search in normal CoPilot mode.""" + session = make_session(user_id=_TEST_USER_ID) + + search_results = [ + {"content_id": "mcp-block-id", "score": 0.9}, + {"content_id": "standard-block-id", "score": 0.8}, + ] + mcp_block = make_mock_block("mcp-block-id", "MCP Tool", BlockType.MCP_TOOL) + standard_block = make_mock_block( + "standard-block-id", "Normal Block", BlockType.STANDARD + ) + + def mock_get_block(block_id): + return { + "mcp-block-id": mcp_block, + "standard-block-id": standard_block, + }.get(block_id) + + mock_search_db = MagicMock() + mock_search_db.unified_hybrid_search = AsyncMock( + return_value=(search_results, 2) + ) + + with patch( + "backend.copilot.tools.find_block.search", + return_value=mock_search_db, + ): + with patch( + "backend.copilot.tools.find_block.get_block", + side_effect=mock_get_block, + ): + tool = FindBlockTool() + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + query="mcp tool", + for_agent_generation=False, + ) + + assert isinstance(response, BlockListResponse) + assert len(response.blocks) == 1 + assert response.blocks[0].id == "standard-block-id" + + @pytest.mark.asyncio(loop_scope="session") + async def test_for_agent_generation_exposes_excluded_ids_in_search(self): + """With for_agent_generation=True, excluded block IDs appear in search results.""" + session = make_session(user_id=_TEST_USER_ID) + orchestrator_id = next(iter(COPILOT_EXCLUDED_BLOCK_IDS)) + + search_results = [ + {"content_id": orchestrator_id, "score": 0.9}, + {"content_id": "normal-block-id", "score": 0.8}, + ] + orchestrator_block = make_mock_block( + orchestrator_id, "Orchestrator", BlockType.STANDARD + ) + normal_block = make_mock_block( + "normal-block-id", "Normal Block", BlockType.STANDARD + ) + + def mock_get_block(block_id): + return { + orchestrator_id: orchestrator_block, + "normal-block-id": normal_block, + }.get(block_id) + + mock_search_db = MagicMock() + mock_search_db.unified_hybrid_search = AsyncMock( + return_value=(search_results, 2) + ) + + with patch( + "backend.copilot.tools.find_block.search", + return_value=mock_search_db, + ): + with patch( + "backend.copilot.tools.find_block.get_block", + side_effect=mock_get_block, + ): + tool = FindBlockTool() + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + query="orchestrator", + for_agent_generation=True, + ) + + assert isinstance(response, BlockListResponse) + assert len(response.blocks) == 2 + block_ids = {b.id for b in response.blocks} + assert orchestrator_id in block_ids + assert "normal-block-id" in block_ids + @pytest.mark.asyncio(loop_scope="session") async def test_response_size_average_chars_per_block(self): """Measure average chars per block in the serialized response.""" @@ -549,8 +737,6 @@ class TestFindBlockDirectLookup: user_id=_TEST_USER_ID, session=session, query=block_id ) - from .models import NoResultsResponse - assert isinstance(response, NoResultsResponse) @pytest.mark.asyncio(loop_scope="session") @@ -571,8 +757,6 @@ class TestFindBlockDirectLookup: user_id=_TEST_USER_ID, session=session, query=block_id ) - from .models import NoResultsResponse - assert isinstance(response, NoResultsResponse) assert "disabled" in response.message.lower() @@ -592,8 +776,6 @@ class TestFindBlockDirectLookup: user_id=_TEST_USER_ID, session=session, query=block_id ) - from .models import NoResultsResponse - assert isinstance(response, NoResultsResponse) assert "not available" in response.message.lower() @@ -613,7 +795,74 @@ class TestFindBlockDirectLookup: user_id=_TEST_USER_ID, session=session, query=orchestrator_id ) - from .models import NoResultsResponse - assert isinstance(response, NoResultsResponse) assert "not available" in response.message.lower() + + @pytest.mark.asyncio(loop_scope="session") + async def test_uuid_lookup_excluded_block_type_allowed_with_for_agent_generation( + self, + ): + """With for_agent_generation=True, excluded block types (INPUT) are visible.""" + session = make_session(user_id=_TEST_USER_ID) + block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d" + block = make_mock_block(block_id, "Agent Input Block", BlockType.INPUT) + + with patch( + "backend.copilot.tools.find_block.get_block", + return_value=block, + ): + tool = FindBlockTool() + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + query=block_id, + for_agent_generation=True, + ) + + assert isinstance(response, BlockListResponse) + assert response.count == 1 + assert response.blocks[0].id == block_id + + @pytest.mark.asyncio(loop_scope="session") + async def test_uuid_lookup_mcp_tool_exposed_with_for_agent_generation(self): + """MCP_TOOL blocks are returned by UUID lookup when for_agent_generation=True.""" + session = make_session(user_id=_TEST_USER_ID) + block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d" + block = make_mock_block(block_id, "MCP Tool", BlockType.MCP_TOOL) + + with patch( + "backend.copilot.tools.find_block.get_block", + return_value=block, + ): + tool = FindBlockTool() + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + query=block_id, + for_agent_generation=True, + ) + + assert isinstance(response, BlockListResponse) + assert response.blocks[0].id == block_id + + @pytest.mark.asyncio(loop_scope="session") + async def test_uuid_lookup_mcp_tool_excluded_without_for_agent_generation(self): + """MCP_TOOL blocks are excluded by UUID lookup in normal CoPilot mode.""" + session = make_session(user_id=_TEST_USER_ID) + block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d" + block = make_mock_block(block_id, "MCP Tool", BlockType.MCP_TOOL) + + with patch( + "backend.copilot.tools.find_block.get_block", + return_value=block, + ): + tool = FindBlockTool() + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + query=block_id, + for_agent_generation=False, + ) + + assert isinstance(response, NoResultsResponse) + assert "run_mcp_tool" in response.message From f835674498fc4213912fff64f14feb7837f818a3 Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Wed, 15 Apr 2026 15:37:11 +0700 Subject: [PATCH 3/5] feat(copilot): standard/advanced model toggle with Opus rate-limit multiplier (#12786) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Why Users have different task complexity needs. Sonnet is fast and cheap for most queries; Opus is more capable for hard reasoning tasks. Exposing this as a simple toggle gives users control without requiring infrastructure complexity. Opus costs 5× more than Sonnet per Anthropic pricing ($15/$75 vs $3/$15 per M tokens). Rather than adding a separate entitlement gate, the rate-limit multiplier (5×) ensures Opus turns deplete the daily/weekly quota proportionally faster — users self-limit via their existing budget. ## What - **Standard/Advanced model toggle** in the chat input toolbar (sky-blue star icon, label only when active — matches the simulation DryRunToggleButton pattern but visually distinct) - **`CopilotLlmModel = Literal["standard", "advanced"]`** — model-agnostic tier names (not tied to Anthropic model names) - **Backend model resolution**: `"advanced"` → `claude-opus-4-6`, `"standard"` → `config.model` (currently Sonnet) - **Rate-limit multiplier**: Opus turns count as 5× in Redis token counters (daily + weekly limits). Does **not** affect `PlatformCostLog` or `cost_usd` — those use real API-reported values - **localStorage persistence** via `Key.COPILOT_MODEL` so preference survives page refresh - **`claude_agent_max_budget_usd`** reduced from $15 to $10 ## How ### Backend - `CopilotLlmModel` type added to `config.py`, imported in routes/executor/service - `stream_chat_completion_sdk` accepts `model: CopilotLlmModel | None` - Model tier resolved early in the SDK path; `_normalize_model_name` strips the OpenRouter provider prefix - `model_cost_multiplier` (1.0 or 5.0) computed from final resolved model name, passed to `persist_and_record_usage` → `record_token_usage` (Redis only) - No separate LD flag needed — rate limit is the gate ### Frontend - `ModelToggleButton` component: sky-blue, star icon, "Advanced" label when active - `copilotModel` state in `useCopilotUIStore` with localStorage hydration - `copilotModelRef` pattern in `useCopilotStream` (avoids recreating `DefaultChatTransport`) - Toggle gated behind `showModeToggle && !isStreaming` in `ChatInput` ## Checklist - [x] Tests added/updated (ModelToggleButton.test.tsx, service_helpers_test.py, token_tracking_test.py) - [x] Rate-limit multiplier only affects Redis counters, not cost tracking - [x] No new LD flag needed --- .../backend/api/features/chat/routes.py | 8 +- .../backend/backend/copilot/config.py | 11 ++- .../backend/copilot/executor/processor.py | 1 + .../backend/backend/copilot/executor/utils.py | 8 +- .../backend/backend/copilot/rate_limit.py | 13 ++- .../backend/copilot/sdk/p0_guardrails_test.py | 2 +- .../backend/backend/copilot/sdk/service.py | 65 +++++++++++++- .../copilot/sdk/service_helpers_test.py | 42 ++++++++++ .../backend/backend/copilot/token_tracking.py | 5 ++ .../backend/copilot/token_tracking_test.py | 1 + .../copilot/__tests__/store.test.ts | 70 ++++++++++++---- .../components/ChatInput/ChatInput.tsx | 38 +++++++-- .../ChatInput/__tests__/ChatInput.test.tsx | 84 +++++++++++++++++-- .../components/ModelToggleButton.tsx | 38 +++++++++ .../__tests__/ModelToggleButton.test.tsx | 36 ++++++++ .../src/app/(platform)/copilot/store.ts | 35 ++++++-- .../app/(platform)/copilot/useCopilotPage.ts | 6 +- .../(platform)/copilot/useCopilotStream.ts | 12 ++- .../frontend/src/app/api/openapi.json | 8 ++ .../src/services/storage/local-storage.ts | 1 + 20 files changed, 439 insertions(+), 45 deletions(-) create mode 100644 autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/ModelToggleButton.tsx create mode 100644 autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/__tests__/ModelToggleButton.test.tsx diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index aa2dc85e15..f8c3e3b804 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -15,7 +15,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator from backend.copilot import service as chat_service from backend.copilot import stream_registry -from backend.copilot.config import ChatConfig, CopilotMode +from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode from backend.copilot.db import get_chat_messages_paginated from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn from backend.copilot.model import ( @@ -139,6 +139,11 @@ class StreamChatRequest(BaseModel): description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. " "If None, uses the server default (extended_thinking).", ) + model: CopilotLlmModel | None = Field( + default=None, + description="Model tier: 'standard' for the default model, 'advanced' for the highest-capability model. " + "If None, the server applies per-user LD targeting then falls back to config.", + ) class CreateSessionRequest(BaseModel): @@ -891,6 +896,7 @@ async def stream_chat_post( context=request.context, file_ids=sanitized_file_ids, mode=request.mode, + model=request.model, ) setup_time = (time.perf_counter() - stream_start_time) * 1000 diff --git a/autogpt_platform/backend/backend/copilot/config.py b/autogpt_platform/backend/backend/copilot/config.py index cfbc6feef4..d5418bf872 100644 --- a/autogpt_platform/backend/backend/copilot/config.py +++ b/autogpt_platform/backend/backend/copilot/config.py @@ -16,6 +16,13 @@ from backend.util.clients import OPENROUTER_BASE_URL # subscription flag → LaunchDarkly COPILOT_SDK → config.use_claude_agent_sdk. CopilotMode = Literal["fast", "extended_thinking"] +# Per-request model tier set by the frontend model toggle. +# 'standard' uses the global config default (currently Sonnet). +# 'advanced' forces the highest-capability model (currently Opus). +# None means no preference — falls through to LD per-user targeting, then config. +# Using tier names instead of model names keeps the contract model-agnostic. +CopilotLlmModel = Literal["standard", "advanced"] + class ChatConfig(BaseSettings): """Configuration for the chat system.""" @@ -163,12 +170,12 @@ class ChatConfig(BaseSettings): "CHAT_CLAUDE_AGENT_MAX_TURNS env var if your workflows need more.", ) claude_agent_max_budget_usd: float = Field( - default=15.0, + default=10.0, ge=0.01, le=1000.0, description="Maximum spend in USD per SDK query. The CLI attempts " "to wrap up gracefully when this budget is reached. " - "Set to $15 to allow most tasks to complete (p50=$5.37, p75=$13.07). " + "Set to $10 to allow most tasks to complete (p50=$5.37, p75=$13.07). " "Override via CHAT_CLAUDE_AGENT_MAX_BUDGET_USD env var.", ) claude_agent_max_thinking_tokens: int = Field( diff --git a/autogpt_platform/backend/backend/copilot/executor/processor.py b/autogpt_platform/backend/backend/copilot/executor/processor.py index cc83b2dd99..0266e57806 100644 --- a/autogpt_platform/backend/backend/copilot/executor/processor.py +++ b/autogpt_platform/backend/backend/copilot/executor/processor.py @@ -351,6 +351,7 @@ class CoPilotProcessor: context=entry.context, file_ids=entry.file_ids, mode=effective_mode, + model=entry.model, ) async for chunk in stream_registry.stream_and_publish( session_id=entry.session_id, diff --git a/autogpt_platform/backend/backend/copilot/executor/utils.py b/autogpt_platform/backend/backend/copilot/executor/utils.py index 0f7d23d9ba..3256f94869 100644 --- a/autogpt_platform/backend/backend/copilot/executor/utils.py +++ b/autogpt_platform/backend/backend/copilot/executor/utils.py @@ -9,7 +9,7 @@ import logging from pydantic import BaseModel -from backend.copilot.config import CopilotMode +from backend.copilot.config import CopilotLlmModel, CopilotMode from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig from backend.util.logging import TruncatedLogger, is_structured_logging_enabled @@ -160,6 +160,9 @@ class CoPilotExecutionEntry(BaseModel): mode: CopilotMode | None = None """Autopilot mode override: 'fast' or 'extended_thinking'. None = server default.""" + model: CopilotLlmModel | None = None + """Per-request model tier: 'standard' or 'advanced'. None = server default.""" + class CancelCoPilotEvent(BaseModel): """Event to cancel a CoPilot operation.""" @@ -180,6 +183,7 @@ async def enqueue_copilot_turn( context: dict[str, str] | None = None, file_ids: list[str] | None = None, mode: CopilotMode | None = None, + model: CopilotLlmModel | None = None, ) -> None: """Enqueue a CoPilot task for processing by the executor service. @@ -192,6 +196,7 @@ async def enqueue_copilot_turn( context: Optional context for the message (e.g., {url: str, content: str}) file_ids: Optional workspace file IDs attached to the user's message mode: Autopilot mode override ('fast' or 'extended_thinking'). None = server default. + model: Per-request model tier ('standard' or 'advanced'). None = server default. """ from backend.util.clients import get_async_copilot_queue @@ -204,6 +209,7 @@ async def enqueue_copilot_turn( context=context, file_ids=file_ids, mode=mode, + model=model, ) queue_client = await get_async_copilot_queue() diff --git a/autogpt_platform/backend/backend/copilot/rate_limit.py b/autogpt_platform/backend/backend/copilot/rate_limit.py index f72d36de23..3124c28992 100644 --- a/autogpt_platform/backend/backend/copilot/rate_limit.py +++ b/autogpt_platform/backend/backend/copilot/rate_limit.py @@ -302,6 +302,7 @@ async def record_token_usage( *, cache_read_tokens: int = 0, cache_creation_tokens: int = 0, + model_cost_multiplier: float = 1.0, ) -> None: """Record token usage for a user across all windows. @@ -315,12 +316,17 @@ async def record_token_usage( ``prompt_tokens`` should be the *uncached* input count (``input_tokens`` from the API response). Cache counts are passed separately. + ``model_cost_multiplier`` scales the final weighted total to reflect + relative model cost. Use 5.0 for Opus (5× more expensive than Sonnet) + so that Opus turns deplete the rate limit faster, proportional to cost. + Args: user_id: The user's ID. prompt_tokens: Uncached input tokens. completion_tokens: Output tokens. cache_read_tokens: Tokens served from prompt cache (10% cost). cache_creation_tokens: Tokens written to prompt cache (25% cost). + model_cost_multiplier: Relative model cost factor (1.0 = Sonnet, 5.0 = Opus). """ prompt_tokens = max(0, prompt_tokens) completion_tokens = max(0, completion_tokens) @@ -332,7 +338,9 @@ async def record_token_usage( + round(cache_creation_tokens * 0.25) + round(cache_read_tokens * 0.1) ) - total = weighted_input + completion_tokens + total = round( + (weighted_input + completion_tokens) * max(1.0, model_cost_multiplier) + ) if total <= 0: return @@ -340,11 +348,12 @@ async def record_token_usage( prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens ) logger.info( - "Recording token usage for %s: raw=%d, weighted=%d " + "Recording token usage for %s: raw=%d, weighted=%d, multiplier=%.1fx " "(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)", user_id[:8], raw_total, total, + model_cost_multiplier, prompt_tokens, cache_read_tokens, cache_creation_tokens, diff --git a/autogpt_platform/backend/backend/copilot/sdk/p0_guardrails_test.py b/autogpt_platform/backend/backend/copilot/sdk/p0_guardrails_test.py index 7077337a79..9305320fea 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/p0_guardrails_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/p0_guardrails_test.py @@ -207,7 +207,7 @@ class TestConfigDefaults: def test_max_budget_usd_default(self): cfg = _make_config() - assert cfg.claude_agent_max_budget_usd == 15.0 + assert cfg.claude_agent_max_budget_usd == 10.0 def test_max_thinking_tokens_default(self): cfg = _make_config() diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py index c7d166adba..3b655ffd1b 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service.py @@ -56,7 +56,7 @@ from backend.executor.cluster_lock import AsyncClusterLock from backend.util.exceptions import NotFoundError from backend.util.settings import Settings -from ..config import ChatConfig, CopilotMode +from ..config import ChatConfig, CopilotLlmModel, CopilotMode from ..constants import ( COPILOT_ERROR_PREFIX, COPILOT_RETRYABLE_ERROR_PREFIX, @@ -132,6 +132,11 @@ _MAX_STREAM_ATTEMPTS = 3 # self-correct. The limit is generous to allow recovery attempts. _EMPTY_TOOL_CALL_LIMIT = 5 +# Cost multiplier for Opus model turns — Opus is ~5× more expensive than Sonnet +# ($15/$75 vs $3/$15 per M tokens). Applied to rate-limit counters so Opus +# turns deplete quota proportionally faster. +_OPUS_COST_MULTIPLIER = 5.0 + # User-facing error shown when the empty-tool-call circuit breaker trips. _CIRCUIT_BREAKER_ERROR_MSG = ( "AutoPilot was unable to complete the tool call " @@ -674,6 +679,48 @@ def _resolve_fallback_model() -> str | None: return _normalize_model_name(raw) +async def _resolve_model_and_multiplier( + model: "CopilotLlmModel | None", + session_id: str, +) -> tuple[str | None, float]: + """Resolve the SDK model string and rate-limit cost multiplier for a turn. + + Priority (highest first): + 1. Explicit per-request ``model`` tier from the frontend toggle. + 2. Global config default (``_resolve_sdk_model()``). + + Returns a ``(sdk_model, cost_multiplier)`` pair. + ``sdk_model`` is ``None`` when the Claude Code subscription default applies. + ``cost_multiplier`` is 5.0 for Opus, 1.0 otherwise. + """ + sdk_model = _resolve_sdk_model() + + if model == "advanced": + sdk_model = _normalize_model_name("anthropic/claude-opus-4-6") + logger.info( + "[SDK] [%s] Per-request model override: advanced (%s)", + session_id[:12] if session_id else "?", + sdk_model, + ) + return sdk_model, _OPUS_COST_MULTIPLIER + + if model == "standard": + # Reset to config default — respects subscription mode (None = CLI default). + sdk_model = _resolve_sdk_model() + logger.info( + "[SDK] [%s] Per-request model override: standard (%s)", + session_id[:12] if session_id else "?", + sdk_model or "subscription-default", + ) + return sdk_model, 1.0 + + # No per-request override; derive multiplier from final resolved model. + cost_multiplier = ( + _OPUS_COST_MULTIPLIER if sdk_model and "opus" in sdk_model else 1.0 + ) + return sdk_model, cost_multiplier + + _MAX_TRANSIENT_BACKOFF_SECONDS = 30 @@ -2155,6 +2202,7 @@ async def stream_chat_completion_sdk( file_ids: list[str] | None = None, permissions: "CopilotPermissions | None" = None, mode: CopilotMode | None = None, + model: CopilotLlmModel | None = None, **_kwargs: Any, ) -> AsyncIterator[StreamBaseResponse]: """Stream chat completion using Claude Agent SDK. @@ -2165,6 +2213,9 @@ async def stream_chat_completion_sdk( saved to the SDK working directory for the Read tool. mode: Accepted for signature compatibility with the baseline path. The SDK path does not currently branch on this value. + model: Per-request model preference from the frontend toggle. + 'advanced' → Claude Opus; 'standard' → global config default. + Takes priority over per-user LaunchDarkly targeting. """ _ = mode # SDK path ignores the requested mode. @@ -2279,6 +2330,10 @@ async def stream_chat_completion_sdk( turn_cache_creation_tokens = 0 turn_cost_usd: float | None = None graphiti_enabled = False + # Defaults ensure the finally block can always reference these safely even when + # an early return (e.g. sdk_cwd error) skips their normal assignment below. + sdk_model: str | None = None + model_cost_multiplier: float = 1.0 # Make sure there is no more code between the lock acquisition and try-block. try: @@ -2490,7 +2545,10 @@ async def stream_chat_completion_sdk( mcp_server = create_copilot_mcp_server(use_e2b=use_e2b) - sdk_model = _resolve_sdk_model() + # Resolve model and cost multiplier (request tier → config default). + sdk_model, model_cost_multiplier = await _resolve_model_and_multiplier( + model, session_id + ) # Track SDK-internal compaction (PreCompact hook → start, next msg → end) compaction = CompactionTracker() @@ -3175,8 +3233,9 @@ async def stream_chat_completion_sdk( cache_creation_tokens=turn_cache_creation_tokens, log_prefix=log_prefix, cost_usd=turn_cost_usd, - model=config.model, + model=sdk_model or config.model, provider="anthropic", + model_cost_multiplier=model_cost_multiplier, ) # --- Persist session messages --- diff --git a/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py b/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py index 5f1487c43b..9d8b4bb135 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py @@ -20,6 +20,7 @@ from .service import ( _is_prompt_too_long, _is_tool_only_message, _iter_sdk_messages, + _normalize_model_name, _reduce_context, _TokenUsage, ) @@ -353,6 +354,47 @@ class TestIsParallelContinuation: assert _is_tool_only_message(msg) is True +# --------------------------------------------------------------------------- +# _normalize_model_name — used by per-request model override +# --------------------------------------------------------------------------- + + +class TestNormalizeModelName: + """Unit tests for the model-name normalisation helper. + + The per-request model toggle calls _normalize_model_name with either + ``"anthropic/claude-opus-4-6"`` (for 'advanced') or ``config.model`` (for + 'standard'). These tests verify the OpenRouter/provider-prefix stripping + that keeps the value compatible with the Claude CLI. + """ + + def test_strips_anthropic_prefix(self): + assert _normalize_model_name("anthropic/claude-opus-4-6") == "claude-opus-4-6" + + def test_strips_openai_prefix(self): + assert _normalize_model_name("openai/gpt-4o") == "gpt-4o" + + def test_strips_google_prefix(self): + assert _normalize_model_name("google/gemini-2.5-flash") == "gemini-2.5-flash" + + def test_already_normalized_unchanged(self): + assert ( + _normalize_model_name("claude-sonnet-4-20250514") + == "claude-sonnet-4-20250514" + ) + + def test_empty_string_unchanged(self): + assert _normalize_model_name("") == "" + + def test_opus_model_roundtrip(self): + """The exact string used for the 'opus' toggle strips correctly.""" + assert _normalize_model_name("anthropic/claude-opus-4-6") == "claude-opus-4-6" + + def test_sonnet_openrouter_model(self): + """Sonnet model as stored in config (OpenRouter-prefixed) strips cleanly.""" + assert _normalize_model_name("anthropic/claude-sonnet-4") == "claude-sonnet-4" + + # --------------------------------------------------------------------------- # _TokenUsage — null-safe accumulation (OpenRouter initial-stream-event bug) # --------------------------------------------------------------------------- diff --git a/autogpt_platform/backend/backend/copilot/token_tracking.py b/autogpt_platform/backend/backend/copilot/token_tracking.py index e84b64d449..19406ced93 100644 --- a/autogpt_platform/backend/backend/copilot/token_tracking.py +++ b/autogpt_platform/backend/backend/copilot/token_tracking.py @@ -96,6 +96,7 @@ async def persist_and_record_usage( cost_usd: float | str | None = None, model: str | None = None, provider: str = "open_router", + model_cost_multiplier: float = 1.0, ) -> int: """Persist token usage to session and record for rate limiting. @@ -109,6 +110,9 @@ async def persist_and_record_usage( log_prefix: Prefix for log messages (e.g. "[SDK]", "[Baseline]"). cost_usd: Optional cost for logging (float from SDK, str otherwise). provider: Cost provider name (e.g. "anthropic", "open_router"). + model_cost_multiplier: Relative model cost factor for rate limiting + (1.0 = Sonnet/default, 5.0 = Opus). Scales the token counter so + more expensive models deplete the rate limit proportionally faster. Returns: The computed total_tokens (prompt + completion; cache excluded). @@ -163,6 +167,7 @@ async def persist_and_record_usage( completion_tokens=completion_tokens, cache_read_tokens=cache_read_tokens, cache_creation_tokens=cache_creation_tokens, + model_cost_multiplier=model_cost_multiplier, ) except Exception as usage_err: logger.warning("%s Failed to record token usage: %s", log_prefix, usage_err) diff --git a/autogpt_platform/backend/backend/copilot/token_tracking_test.py b/autogpt_platform/backend/backend/copilot/token_tracking_test.py index 04c7667368..11757ce541 100644 --- a/autogpt_platform/backend/backend/copilot/token_tracking_test.py +++ b/autogpt_platform/backend/backend/copilot/token_tracking_test.py @@ -230,6 +230,7 @@ class TestRateLimitRecording: completion_tokens=50, cache_read_tokens=1000, cache_creation_tokens=200, + model_cost_multiplier=1.0, ) @pytest.mark.asyncio diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/store.test.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/store.test.ts index f993daf58d..fd95bbdb2c 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/store.test.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/store.test.ts @@ -1,4 +1,4 @@ -import { describe, expect, it, beforeEach, vi } from "vitest"; +import { describe, expect, it, beforeEach, afterEach, vi } from "vitest"; import { useCopilotUIStore } from "../store"; vi.mock("@sentry/nextjs", () => ({ @@ -22,7 +22,8 @@ describe("useCopilotUIStore", () => { isNotificationsEnabled: false, isSoundEnabled: true, showNotificationDialog: false, - copilotMode: "extended_thinking", + copilotChatMode: "extended_thinking", + copilotLlmModel: "standard", }); }); @@ -154,35 +155,52 @@ describe("useCopilotUIStore", () => { }); }); - describe("copilotMode", () => { + describe("copilotChatMode", () => { it("defaults to extended_thinking", () => { - expect(useCopilotUIStore.getState().copilotMode).toBe( + expect(useCopilotUIStore.getState().copilotChatMode).toBe( "extended_thinking", ); }); it("sets mode to fast", () => { - useCopilotUIStore.getState().setCopilotMode("fast"); - expect(useCopilotUIStore.getState().copilotMode).toBe("fast"); + useCopilotUIStore.getState().setCopilotChatMode("fast"); + expect(useCopilotUIStore.getState().copilotChatMode).toBe("fast"); }); it("sets mode back to extended_thinking", () => { - useCopilotUIStore.getState().setCopilotMode("fast"); - useCopilotUIStore.getState().setCopilotMode("extended_thinking"); - expect(useCopilotUIStore.getState().copilotMode).toBe( + useCopilotUIStore.getState().setCopilotChatMode("fast"); + useCopilotUIStore.getState().setCopilotChatMode("extended_thinking"); + expect(useCopilotUIStore.getState().copilotChatMode).toBe( "extended_thinking", ); }); - it("does not persist mode to localStorage", () => { - useCopilotUIStore.getState().setCopilotMode("fast"); - expect(window.localStorage.getItem("copilot-mode")).toBeNull(); + it("persists mode to localStorage", () => { + useCopilotUIStore.getState().setCopilotChatMode("fast"); + expect(window.localStorage.getItem("copilot-mode")).toBe("fast"); + }); + }); + + describe("copilotLlmModel", () => { + it("defaults to standard", () => { + expect(useCopilotUIStore.getState().copilotLlmModel).toBe("standard"); + }); + + it("sets model to advanced", () => { + useCopilotUIStore.getState().setCopilotLlmModel("advanced"); + expect(useCopilotUIStore.getState().copilotLlmModel).toBe("advanced"); + }); + + it("persists model to localStorage", () => { + useCopilotUIStore.getState().setCopilotLlmModel("advanced"); + expect(window.localStorage.getItem("copilot-model")).toBe("advanced"); }); }); describe("clearCopilotLocalData", () => { it("resets state and clears localStorage keys", () => { - useCopilotUIStore.getState().setCopilotMode("fast"); + useCopilotUIStore.getState().setCopilotChatMode("fast"); + useCopilotUIStore.getState().setCopilotLlmModel("advanced"); useCopilotUIStore.getState().setNotificationsEnabled(true); useCopilotUIStore.getState().toggleSound(); useCopilotUIStore.getState().addCompletedSession("s1"); @@ -190,7 +208,8 @@ describe("useCopilotUIStore", () => { useCopilotUIStore.getState().clearCopilotLocalData(); const state = useCopilotUIStore.getState(); - expect(state.copilotMode).toBe("extended_thinking"); + expect(state.copilotChatMode).toBe("extended_thinking"); + expect(state.copilotLlmModel).toBe("standard"); expect(state.isNotificationsEnabled).toBe(false); expect(state.isSoundEnabled).toBe(true); expect(state.completedSessionIDs.size).toBe(0); @@ -198,6 +217,8 @@ describe("useCopilotUIStore", () => { window.localStorage.getItem("copilot-notifications-enabled"), ).toBeNull(); expect(window.localStorage.getItem("copilot-sound-enabled")).toBeNull(); + expect(window.localStorage.getItem("copilot-mode")).toBeNull(); + expect(window.localStorage.getItem("copilot-model")).toBeNull(); expect( window.localStorage.getItem("copilot-completed-sessions"), ).toBeNull(); @@ -222,3 +243,24 @@ describe("useCopilotUIStore", () => { }); }); }); + +describe("useCopilotUIStore localStorage initialisation", () => { + afterEach(() => { + vi.resetModules(); + window.localStorage.clear(); + }); + + it("reads fast chat mode from localStorage on store creation", async () => { + window.localStorage.setItem("copilot-mode", "fast"); + vi.resetModules(); + const { useCopilotUIStore: fresh } = await import("../store"); + expect(fresh.getState().copilotChatMode).toBe("fast"); + }); + + it("reads advanced model from localStorage on store creation", async () => { + window.localStorage.setItem("copilot-model", "advanced"); + vi.resetModules(); + const { useCopilotUIStore: fresh } = await import("../store"); + expect(fresh.getState().copilotLlmModel).toBe("advanced"); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/ChatInput.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/ChatInput.tsx index d1e1ca4f9d..b6fedb722e 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/ChatInput.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/ChatInput.tsx @@ -13,6 +13,7 @@ import { ChangeEvent, useEffect, useState } from "react"; import { AttachmentMenu } from "./components/AttachmentMenu"; import { DryRunToggleButton } from "./components/DryRunToggleButton"; import { FileChips } from "./components/FileChips"; +import { ModelToggleButton } from "./components/ModelToggleButton"; import { ModeToggleButton } from "./components/ModeToggleButton"; import { RecordingButton } from "./components/RecordingButton"; import { RecordingIndicator } from "./components/RecordingIndicator"; @@ -50,16 +51,22 @@ export function ChatInput({ onDroppedFilesConsumed, hasSession = false, }: Props) { - const { copilotMode, setCopilotMode, isDryRun, setIsDryRun } = - useCopilotUIStore(); + const { + copilotChatMode, + setCopilotChatMode, + copilotLlmModel, + setCopilotLlmModel, + isDryRun, + setIsDryRun, + } = useCopilotUIStore(); const showModeToggle = useGetFlag(Flag.CHAT_MODE_OPTION); const showDryRunToggle = showModeToggle; const [files, setFiles] = useState([]); function handleToggleMode() { const next = - copilotMode === "extended_thinking" ? "fast" : "extended_thinking"; - setCopilotMode(next); + copilotChatMode === "extended_thinking" ? "fast" : "extended_thinking"; + setCopilotChatMode(next); toast({ title: next === "fast" @@ -72,6 +79,21 @@ export function ChatInput({ }); } + function handleToggleModel() { + const next = copilotLlmModel === "advanced" ? "standard" : "advanced"; + setCopilotLlmModel(next); + toast({ + title: + next === "advanced" + ? "Switched to Advanced model" + : "Switched to Standard model", + description: + next === "advanced" + ? "Using the highest-capability model." + : "Using the balanced standard model.", + }); + } + function handleToggleDryRun() { const next = !isDryRun; setIsDryRun(next); @@ -198,10 +220,16 @@ export function ChatInput({ /> {showModeToggle && !isStreaming && ( )} + {showModeToggle && !isStreaming && ( + + )} {showDryRunToggle && (!hasSession || isDryRun) && ( { +const mockSetCopilotChatMode = vi.fn((mode: string) => { mockCopilotMode = mode; }); +let mockCopilotLlmModel = "standard"; +const mockSetCopilotLlmModel = vi.fn((model: string) => { + mockCopilotLlmModel = model; +}); + vi.mock("@/app/(platform)/copilot/store", () => ({ useCopilotUIStore: () => ({ - copilotMode: mockCopilotMode, - setCopilotMode: mockSetCopilotMode, + copilotChatMode: mockCopilotMode, + setCopilotChatMode: mockSetCopilotChatMode, + copilotLlmModel: mockCopilotLlmModel, + setCopilotLlmModel: mockSetCopilotLlmModel, initialPrompt: null, setInitialPrompt: vi.fn(), }), @@ -107,6 +114,7 @@ afterEach(() => { cleanup(); vi.clearAllMocks(); mockCopilotMode = "extended_thinking"; + mockCopilotLlmModel = "standard"; }); describe("ChatInput mode toggle", () => { @@ -141,7 +149,7 @@ describe("ChatInput mode toggle", () => { mockCopilotMode = "extended_thinking"; render(); fireEvent.click(screen.getByLabelText(/switch to fast mode/i)); - expect(mockSetCopilotMode).toHaveBeenCalledWith("fast"); + expect(mockSetCopilotChatMode).toHaveBeenCalledWith("fast"); }); it("toggles from fast to extended_thinking on click", () => { @@ -149,7 +157,7 @@ describe("ChatInput mode toggle", () => { mockCopilotMode = "fast"; render(); fireEvent.click(screen.getByLabelText(/switch to extended thinking/i)); - expect(mockSetCopilotMode).toHaveBeenCalledWith("extended_thinking"); + expect(mockSetCopilotChatMode).toHaveBeenCalledWith("extended_thinking"); }); it("hides toggle button when streaming", () => { @@ -187,3 +195,69 @@ describe("ChatInput mode toggle", () => { ); }); }); + +describe("ChatInput model toggle", () => { + it("renders model toggle button when flag is enabled", () => { + mockFlagValue = true; + render(); + expect(screen.getByLabelText(/switch to advanced model/i)).toBeDefined(); + }); + + it("does not render model toggle when flag is disabled", () => { + mockFlagValue = false; + render(); + expect( + screen.queryByLabelText(/switch to (advanced|standard) model/i), + ).toBeNull(); + }); + + it("toggles from standard to advanced on click", () => { + mockFlagValue = true; + mockCopilotLlmModel = "standard"; + render(); + fireEvent.click(screen.getByLabelText(/switch to advanced model/i)); + expect(mockSetCopilotLlmModel).toHaveBeenCalledWith("advanced"); + }); + + it("toggles from advanced to standard on click", () => { + mockFlagValue = true; + mockCopilotLlmModel = "advanced"; + render(); + fireEvent.click(screen.getByLabelText(/switch to standard model/i)); + expect(mockSetCopilotLlmModel).toHaveBeenCalledWith("standard"); + }); + + it("hides model toggle when streaming", () => { + mockFlagValue = true; + render(); + expect( + screen.queryByLabelText(/switch to (advanced|standard) model/i), + ).toBeNull(); + }); + + it("shows a toast when switching to advanced", async () => { + const { toast } = await import("@/components/molecules/Toast/use-toast"); + mockFlagValue = true; + mockCopilotLlmModel = "standard"; + render(); + fireEvent.click(screen.getByLabelText(/switch to advanced model/i)); + expect(toast).toHaveBeenCalledWith( + expect.objectContaining({ + title: expect.stringMatching(/switched to advanced model/i), + }), + ); + }); + + it("shows a toast when switching to standard", async () => { + const { toast } = await import("@/components/molecules/Toast/use-toast"); + mockFlagValue = true; + mockCopilotLlmModel = "advanced"; + render(); + fireEvent.click(screen.getByLabelText(/switch to standard model/i)); + expect(toast).toHaveBeenCalledWith( + expect.objectContaining({ + title: expect.stringMatching(/switched to standard model/i), + }), + ); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/ModelToggleButton.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/ModelToggleButton.tsx new file mode 100644 index 0000000000..cb3bc25f4f --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/ModelToggleButton.tsx @@ -0,0 +1,38 @@ +"use client"; + +import { cn } from "@/lib/utils"; +import { Cpu } from "@phosphor-icons/react"; +import type { CopilotLlmModel } from "../../../store"; + +interface Props { + model: CopilotLlmModel; + onToggle: () => void; +} + +export function ModelToggleButton({ model, onToggle }: Props) { + const isAdvanced = model === "advanced"; + return ( + + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/__tests__/ModelToggleButton.test.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/__tests__/ModelToggleButton.test.tsx new file mode 100644 index 0000000000..a77cb5b6f4 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/__tests__/ModelToggleButton.test.tsx @@ -0,0 +1,36 @@ +import { render, screen, fireEvent, cleanup } from "@testing-library/react"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import { ModelToggleButton } from "../ModelToggleButton"; + +afterEach(cleanup); + +describe("ModelToggleButton", () => { + it("shows no label when model is standard", () => { + render(); + expect(screen.queryByText("Advanced")).toBeNull(); + }); + + it("shows Advanced label when model is advanced", () => { + render(); + expect(screen.getByText("Advanced")).toBeTruthy(); + }); + + it("calls onToggle when clicked", () => { + const onToggle = vi.fn(); + render(); + fireEvent.click(screen.getByRole("button")); + expect(onToggle).toHaveBeenCalledTimes(1); + }); + + it("sets aria-pressed=false for standard", () => { + render(); + const btn = screen.getByLabelText("Switch to Advanced model"); + expect(btn.getAttribute("aria-pressed")).toBe("false"); + }); + + it("sets aria-pressed=true for advanced", () => { + render(); + const btn = screen.getByLabelText("Switch to Standard model"); + expect(btn.getAttribute("aria-pressed")).toBe("true"); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/store.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/store.ts index d63c0bd76a..d8dcbd132c 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/store.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/store.ts @@ -53,6 +53,9 @@ export const DEFAULT_PANEL_WIDTH = 600; /** Autopilot response mode. */ export type CopilotMode = "extended_thinking" | "fast"; +/** Per-request model tier. 'standard' = current default; 'advanced' = highest-capability. */ +export type CopilotLlmModel = "standard" | "advanced"; + const isClient = typeof window !== "undefined"; function getPersistedWidth(): number { @@ -134,8 +137,12 @@ interface CopilotUIState { goBackArtifact: () => void; /** Autopilot mode: 'extended_thinking' (default) or 'fast'. */ - copilotMode: CopilotMode; - setCopilotMode: (mode: CopilotMode) => void; + copilotChatMode: CopilotMode; + setCopilotChatMode: (mode: CopilotMode) => void; + + /** Model tier: 'standard' (default) or 'advanced' (highest-capability). */ + copilotLlmModel: CopilotLlmModel; + setCopilotLlmModel: (model: CopilotLlmModel) => void; /** Developer dry-run mode: sessions created with dry_run=true. */ isDryRun: boolean; @@ -298,9 +305,22 @@ export const useCopilotUIStore = create((set) => ({ }; }), - copilotMode: "extended_thinking", - setCopilotMode: (mode) => { - set({ copilotMode: mode }); + copilotChatMode: (() => { + const saved = isClient ? storage.get(Key.COPILOT_MODE) : null; + return saved === "fast" ? "fast" : "extended_thinking"; + })(), + setCopilotChatMode: (mode) => { + storage.set(Key.COPILOT_MODE, mode); + set({ copilotChatMode: mode }); + }, + + copilotLlmModel: (() => { + const saved = isClient ? storage.get(Key.COPILOT_MODEL) : null; + return saved === "advanced" ? "advanced" : "standard"; + })(), + setCopilotLlmModel: (model) => { + storage.set(Key.COPILOT_MODEL, model); + set({ copilotLlmModel: model }); }, isDryRun: isClient && storage.get(Key.COPILOT_DRY_RUN) === "true", @@ -322,6 +342,8 @@ export const useCopilotUIStore = create((set) => ({ storage.clean(Key.COPILOT_ARTIFACT_PANEL_WIDTH); storage.clean(Key.COPILOT_COMPLETED_SESSIONS); storage.clean(Key.COPILOT_DRY_RUN); + storage.clean(Key.COPILOT_MODE); + storage.clean(Key.COPILOT_MODEL); set({ completedSessionIDs: new Set(), isNotificationsEnabled: false, @@ -334,7 +356,8 @@ export const useCopilotUIStore = create((set) => ({ activeArtifact: null, history: [], }, - copilotMode: "extended_thinking", + copilotChatMode: "extended_thinking", + copilotLlmModel: "standard", isDryRun: false, }); if (isClient) { diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts index f8b0387c6b..01302c9f81 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts @@ -42,7 +42,8 @@ export function useCopilotPage() { setSessionToDelete, isDrawerOpen, setDrawerOpen, - copilotMode, + copilotChatMode, + copilotLlmModel, isDryRun, } = useCopilotUIStore(); @@ -78,7 +79,8 @@ export function useCopilotPage() { hydratedMessages, hasActiveStream, refetchSession, - copilotMode: isModeToggleEnabled ? copilotMode : undefined, + copilotMode: isModeToggleEnabled ? copilotChatMode : undefined, + copilotModel: isModeToggleEnabled ? copilotLlmModel : undefined, }); const { olderMessages, hasMore, isLoadingMore, loadMore } = diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts index 918047d3d8..14ea672bfb 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts @@ -18,7 +18,7 @@ import { resolveInProgressTools, getSendSuppressionReason, } from "./helpers"; -import type { CopilotMode } from "./store"; +import type { CopilotLlmModel, CopilotMode } from "./store"; const RECONNECT_BASE_DELAY_MS = 1_000; const RECONNECT_MAX_ATTEMPTS = 3; @@ -33,6 +33,8 @@ interface UseCopilotStreamArgs { refetchSession: () => Promise<{ data?: unknown }>; /** Autopilot mode to use for requests. `undefined` = let backend decide via feature flags. */ copilotMode: CopilotMode | undefined; + /** Model tier override. `undefined` = let backend decide. */ + copilotModel: CopilotLlmModel | undefined; } export function useCopilotStream({ @@ -41,17 +43,20 @@ export function useCopilotStream({ hasActiveStream, refetchSession, copilotMode, + copilotModel, }: UseCopilotStreamArgs) { const queryClient = useQueryClient(); const [rateLimitMessage, setRateLimitMessage] = useState(null); function dismissRateLimit() { setRateLimitMessage(null); } - // Use a ref for copilotMode so the transport closure always reads the - // latest value without recreating the DefaultChatTransport (which would + // Use refs for copilotMode and copilotModel so the transport closure always reads + // the latest value without recreating the DefaultChatTransport (which would // reset useChat's internal Chat instance and break mid-session streaming). const copilotModeRef = useRef(copilotMode); copilotModeRef.current = copilotMode; + const copilotModelRef = useRef(copilotModel); + copilotModelRef.current = copilotModel; // Connect directly to the Python backend for SSE, bypassing the Next.js // serverless proxy. This eliminates the Vercel 800s function timeout that @@ -83,6 +88,7 @@ export function useCopilotStream({ context: null, file_ids: fileIds && fileIds.length > 0 ? fileIds : null, mode: copilotModeRef.current ?? null, + model: copilotModelRef.current ?? null, }, headers: await getCopilotAuthHeaders(), }; diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index 732ef569d9..32e91bfd51 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -13931,6 +13931,14 @@ ], "title": "Mode", "description": "Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. If None, uses the server default (extended_thinking)." + }, + "model": { + "anyOf": [ + { "type": "string", "enum": ["standard", "advanced"] }, + { "type": "null" } + ], + "title": "Model", + "description": "Model tier: 'standard' for the default model, 'advanced' for the highest-capability model. If None, the server applies per-user LD targeting then falls back to config." } }, "type": "object", diff --git a/autogpt_platform/frontend/src/services/storage/local-storage.ts b/autogpt_platform/frontend/src/services/storage/local-storage.ts index de31967d53..b5c0392ecd 100644 --- a/autogpt_platform/frontend/src/services/storage/local-storage.ts +++ b/autogpt_platform/frontend/src/services/storage/local-storage.ts @@ -17,6 +17,7 @@ export enum Key { COPILOT_NOTIFICATION_DIALOG_DISMISSED = "copilot-notification-dialog-dismissed", COPILOT_ARTIFACT_PANEL_WIDTH = "copilot-artifact-panel-width", COPILOT_MODE = "copilot-mode", + COPILOT_MODEL = "copilot-model", COPILOT_COMPLETED_SESSIONS = "copilot-completed-sessions", COPILOT_DRY_RUN = "copilot-dry-run", } From 0284614df06d3e12e206c31fea3ea722af510cb4 Mon Sep 17 00:00:00 2001 From: Ubbe Date: Wed, 15 Apr 2026 16:50:19 +0700 Subject: [PATCH 4/5] fix(copilot): abort SSE stream and disconnect backend listeners on session switch (#12766) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Fixes stream disconnection bugs where the UI shows "running" with no output when users switch between copilot chat sessions. The root cause is that the old SSE fetch is not aborted and backend XREAD listeners keep running until timeout when switching sessions. ### Changes **Frontend (`useCopilotStream.ts`, `helpers.ts`)** - Call `sdkStop()` on session switch to abort the in-flight SSE fetch from the old session's transport - Fire-and-forget `DELETE` to new backend disconnect endpoint so server-side listeners release immediately - Store `resumeStream` and `sdkStop` in refs to fix stale closure bugs in: - Wake re-sync visibility handler (could call stale `resumeStream` after tab sleep) - Reconnect timer callback (could target wrong session's transport) - Resume effect (captured stale `resumeStream` during rapid session switches) **Backend (`stream_registry.py`, `routes.py`)** - Add `disconnect_all_listeners(session_id)` to stream registry — iterates active listener tasks, cancels any matching the session - Add `DELETE /sessions/{session_id}/stream` endpoint — auth-protected, calls `disconnect_all_listeners`, returns 204 ### Why Reported by multiple team members: when using Autopilot for anything serious, the frontend loses the SSE connection — particularly when switching between conversations. The backend completes fine (refreshing shows full output), but the UI gets stuck showing "running". This is the worst UX bug we have right now because real users will never know to refresh. ### How to test 1. Start a long-running autopilot task (e.g., "build a snake game") 2. While it's streaming, switch to a different chat session 3. Switch back — the UI should correctly show the completed output or resume the stream 4. Verify no "stuck running" state ## Test plan - [ ] Manual: switch sessions during active stream — no stuck "running" state - [ ] Manual: background tab for >30s during stream, return — wake re-sync works - [ ] Manual: trigger reconnect (kill network briefly) — reconnects to correct session - [ ] Verify: `pnpm lint`, `pnpm types`, `poetry run lint` all pass 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.6 (1M context) Co-authored-by: majdyz --- .../backend/api/features/chat/routes.py | 25 ++++ .../backend/api/features/chat/routes_test.py | 45 +++++++ .../backend/copilot/stream_registry.py | 47 ++++++++ .../backend/copilot/stream_registry_test.py | 110 ++++++++++++++++++ .../src/app/(platform)/copilot/helpers.ts | 15 ++- .../(platform)/copilot/useCopilotStream.ts | 60 +++++++--- .../frontend/src/app/api/openapi.json | 29 +++++ 7 files changed, 313 insertions(+), 18 deletions(-) create mode 100644 autogpt_platform/backend/backend/copilot/stream_registry_test.py diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index f8c3e3b804..ac7325e201 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -381,6 +381,31 @@ async def delete_session( return Response(status_code=204) +@router.delete( + "/sessions/{session_id}/stream", + dependencies=[Security(auth.requires_user)], + status_code=204, +) +async def disconnect_session_stream( + session_id: str, + user_id: Annotated[str, Security(auth.get_user_id)], +) -> Response: + """Disconnect all active SSE listeners for a session. + + Called by the frontend when the user switches away from a chat so the + backend releases XREAD listeners immediately rather than waiting for + the 5-10 s timeout. + """ + session = await get_chat_session(session_id, user_id) + if not session: + raise HTTPException( + status_code=404, + detail=f"Session {session_id} not found or access denied", + ) + await stream_registry.disconnect_all_listeners(session_id) + return Response(status_code=204) + + @router.patch( "/sessions/{session_id}/title", summary="Update session title", diff --git a/autogpt_platform/backend/backend/api/features/chat/routes_test.py b/autogpt_platform/backend/backend/api/features/chat/routes_test.py index f3896c7098..74259b3463 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes_test.py @@ -677,3 +677,48 @@ class TestStripInjectedContext: result = _strip_injected_context(msg) # Without a role, the helper short-circuits without touching content. assert result["content"] == "hello" + + +# ─── DELETE /sessions/{id}/stream — disconnect listeners ────────────── + + +def test_disconnect_stream_returns_204_and_awaits_registry( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + mock_session = MagicMock() + mocker.patch( + "backend.api.features.chat.routes.get_chat_session", + new_callable=AsyncMock, + return_value=mock_session, + ) + mock_disconnect = mocker.patch( + "backend.api.features.chat.routes.stream_registry.disconnect_all_listeners", + new_callable=AsyncMock, + return_value=2, + ) + + response = client.delete("/sessions/sess-1/stream") + + assert response.status_code == 204 + mock_disconnect.assert_awaited_once_with("sess-1") + + +def test_disconnect_stream_returns_404_when_session_missing( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + mocker.patch( + "backend.api.features.chat.routes.get_chat_session", + new_callable=AsyncMock, + return_value=None, + ) + mock_disconnect = mocker.patch( + "backend.api.features.chat.routes.stream_registry.disconnect_all_listeners", + new_callable=AsyncMock, + ) + + response = client.delete("/sessions/unknown-session/stream") + + assert response.status_code == 404 + mock_disconnect.assert_not_awaited() diff --git a/autogpt_platform/backend/backend/copilot/stream_registry.py b/autogpt_platform/backend/backend/copilot/stream_registry.py index 163b8c1bab..030763dbca 100644 --- a/autogpt_platform/backend/backend/copilot/stream_registry.py +++ b/autogpt_platform/backend/backend/copilot/stream_registry.py @@ -1149,3 +1149,50 @@ async def unsubscribe_from_session( ) logger.debug(f"Successfully unsubscribed from session {session_id}") + + +async def disconnect_all_listeners(session_id: str) -> int: + """Cancel every active listener task for *session_id*. + + Called when the frontend switches away from a session and wants the + backend to release resources immediately rather than waiting for the + XREAD timeout. + + Scope / limitations (best-effort optimisation, not a correctness primitive): + - Pod-local: ``_listener_sessions`` is in-memory. If the DELETE request + lands on a different worker than the one serving the SSE, no listener + is cancelled here — the SSE worker still releases on its XREAD timeout. + - Session-scoped (not subscriber-scoped): cancels every active listener + for the session on this pod. In the rare case a single user opens two + SSE connections to the same session on the same pod (e.g. two tabs), + both would be torn down. Cross-pod, subscriber-scoped cancellation + would require a Redis pub/sub fan-out with per-listener tokens; that + is not implemented here because the XREAD timeout already bounds the + worst case. + + Returns the number of listener tasks that were cancelled. + """ + to_cancel: list[tuple[int, asyncio.Task]] = [ + (qid, task) + for qid, (sid, task) in list(_listener_sessions.items()) + if sid == session_id and not task.done() + ] + + for qid, task in to_cancel: + _listener_sessions.pop(qid, None) + task.cancel() + + cancelled = 0 + for _qid, task in to_cancel: + try: + await asyncio.wait_for(task, timeout=5.0) + except asyncio.CancelledError: + cancelled += 1 + except asyncio.TimeoutError: + pass + except Exception as e: + logger.error(f"Error cancelling listener for session {session_id}: {e}") + + if cancelled: + logger.info(f"Disconnected {cancelled} listener(s) for session {session_id}") + return cancelled diff --git a/autogpt_platform/backend/backend/copilot/stream_registry_test.py b/autogpt_platform/backend/backend/copilot/stream_registry_test.py new file mode 100644 index 0000000000..a09940a4a8 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/stream_registry_test.py @@ -0,0 +1,110 @@ +"""Tests for disconnect_all_listeners in stream_registry.""" + +import asyncio +from unittest.mock import AsyncMock, patch + +import pytest + +from backend.copilot import stream_registry + + +@pytest.fixture(autouse=True) +def _clear_listener_sessions(): + stream_registry._listener_sessions.clear() + yield + stream_registry._listener_sessions.clear() + + +async def _sleep_forever(): + try: + await asyncio.sleep(3600) + except asyncio.CancelledError: + raise + + +@pytest.mark.asyncio +async def test_disconnect_all_listeners_cancels_matching_session(): + task_a = asyncio.create_task(_sleep_forever()) + task_b = asyncio.create_task(_sleep_forever()) + task_other = asyncio.create_task(_sleep_forever()) + + stream_registry._listener_sessions[1] = ("sess-1", task_a) + stream_registry._listener_sessions[2] = ("sess-1", task_b) + stream_registry._listener_sessions[3] = ("sess-other", task_other) + + try: + cancelled = await stream_registry.disconnect_all_listeners("sess-1") + + assert cancelled == 2 + assert task_a.cancelled() + assert task_b.cancelled() + assert not task_other.done() + # Matching entries are removed, non-matching entries remain. + assert 1 not in stream_registry._listener_sessions + assert 2 not in stream_registry._listener_sessions + assert 3 in stream_registry._listener_sessions + finally: + task_other.cancel() + try: + await task_other + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_disconnect_all_listeners_no_match_returns_zero(): + task = asyncio.create_task(_sleep_forever()) + stream_registry._listener_sessions[1] = ("sess-other", task) + + try: + cancelled = await stream_registry.disconnect_all_listeners("sess-missing") + + assert cancelled == 0 + assert not task.done() + assert 1 in stream_registry._listener_sessions + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_disconnect_all_listeners_skips_already_done_tasks(): + async def _noop(): + return None + + done_task = asyncio.create_task(_noop()) + await done_task + stream_registry._listener_sessions[1] = ("sess-1", done_task) + + cancelled = await stream_registry.disconnect_all_listeners("sess-1") + + # Done tasks are filtered out before cancellation. + assert cancelled == 0 + + +@pytest.mark.asyncio +async def test_disconnect_all_listeners_empty_registry(): + cancelled = await stream_registry.disconnect_all_listeners("sess-1") + assert cancelled == 0 + + +@pytest.mark.asyncio +async def test_disconnect_all_listeners_timeout_not_counted(): + """Tasks that don't respond to cancellation (timeout) are not counted.""" + task = asyncio.create_task(_sleep_forever()) + stream_registry._listener_sessions[1] = ("sess-1", task) + + with patch.object( + asyncio, "wait_for", new=AsyncMock(side_effect=asyncio.TimeoutError) + ): + cancelled = await stream_registry.disconnect_all_listeners("sess-1") + + assert cancelled == 0 + task.cancel() + try: + await task + except asyncio.CancelledError: + pass diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts index 66c437eb86..34e2bea51a 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts @@ -2,6 +2,8 @@ import { getSystemHeaders } from "@/lib/impersonation"; import { getWebSocketToken } from "@/lib/supabase/actions"; import type { UIMessage } from "ai"; +import { deleteV2DisconnectSessionStream } from "@/app/api/__generated__/endpoints/chat/chat"; + export const ORIGINAL_TITLE = "AutoGPT"; /** @@ -154,7 +156,18 @@ export function shouldSuppressDuplicateSend( } /** - * Deduplicate messages by ID and by content fingerprint. + * Fire-and-forget: tell the backend to release XREAD listeners for a session. + * + * Called on session switch so the backend doesn't wait for its 5-10 s timeout + * before cleaning up. Failures are silently ignored — the backend will + * eventually clean up on its own. + */ +export function disconnectSessionStream(sessionId: string): void { + deleteV2DisconnectSessionStream(sessionId).catch(() => {}); +} + +/** + * Deduplicate messages by ID and by consecutive content fingerprint. * * ID dedup catches exact duplicates within the same source. * Content dedup uses a composite key of `role + preceding-user-message-id + diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts index 14ea672bfb..85709f23d9 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts @@ -17,6 +17,7 @@ import { hasActiveBackendStream, resolveInProgressTools, getSendSuppressionReason, + disconnectSessionStream, } from "./helpers"; import type { CopilotLlmModel, CopilotMode } from "./store"; @@ -153,16 +154,15 @@ export function useCopilotStream({ reconnectTimerRef.current = setTimeout(() => { isReconnectScheduledRef.current = false; setIsReconnectScheduled(false); - // Strip any stale in-progress assistant message before resuming. - // The backend replays from "0-0", so the partial message would - // otherwise sit alongside the fully-replayed version. + // Strip the stale in-progress assistant message before resuming — + // the backend replays from "0-0", so keeping it would duplicate parts. setMessages((prev) => { if (prev.length > 0 && prev[prev.length - 1].role === "assistant") { return prev.slice(0, -1); } return prev; }); - resumeStream(); + resumeStreamRef.current(); }, delay); } @@ -260,6 +260,14 @@ export function useCopilotStream({ }, }); + // Keep stable refs to sdkStop and resumeStream so that async callbacks + // (session-switch cleanup, wake re-sync, reconnect timer) always call the + // latest version without stale-closure bugs. + const sdkStopRef = useRef(sdkStop); + sdkStopRef.current = sdkStop; + const resumeStreamRef = useRef(resumeStream); + resumeStreamRef.current = resumeStream; + // Wrap sdkSendMessage to guard against re-sending the user message during a // reconnect cycle. If the session already has the message (i.e. we are in a // reconnect/resume flow), only GET-resume is safe — never re-POST. @@ -386,7 +394,7 @@ export function useCopilotStream({ } return prev; }); - await resumeStream(); + await resumeStreamRef.current(); } // If !backendActive, the refetch will update hydratedMessages via // React Query, and the hydration effect below will merge them in. @@ -409,7 +417,7 @@ export function useCopilotStream({ return () => { document.removeEventListener("visibilitychange", onVisibilityChange); }; - }, [refetchSession, setMessages, resumeStream]); + }, [refetchSession, setMessages]); // Hydrate messages from REST API when not actively streaming useEffect(() => { @@ -425,8 +433,34 @@ export function useCopilotStream({ // Track resume state per session const hasResumedRef = useRef>(new Map()); - // Clean up reconnect state on session switch + // Clean up reconnect state on session switch. + // Abort the old stream's in-flight fetch and tell the backend to release + // its XREAD listeners immediately (fire-and-forget). + const prevStreamSessionRef = useRef(sessionId); useEffect(() => { + const prevSid = prevStreamSessionRef.current; + prevStreamSessionRef.current = sessionId; + + const isSwitching = Boolean(prevSid && prevSid !== sessionId); + if (isSwitching) { + // Mark BEFORE stopping so the old stream's async onError (which fires + // after the abort) sees the flag and short-circuits the reconnect path. + // Without this, the AbortError can queue a reconnect against the new + // session's `sessionId` (captured in the fresh onError closure). + isUserStoppingRef.current = true; + sdkStopRef.current(); + disconnectSessionStream(prevSid!); + // Schedule the reset as a task (not a microtask) so it runs AFTER the + // aborted fetch's onError has fired — otherwise the new session would + // be stuck with the "user stopping" flag set, preventing auto-resume + // when hydration detects an active backend stream. + setTimeout(() => { + isUserStoppingRef.current = false; + }, 0); + } else { + isUserStoppingRef.current = false; + } + clearTimeout(reconnectTimerRef.current); reconnectTimerRef.current = undefined; reconnectAttemptsRef.current = 0; @@ -434,7 +468,6 @@ export function useCopilotStream({ setIsReconnectScheduled(false); setRateLimitMessage(null); hasShownDisconnectToast.current = false; - isUserStoppingRef.current = false; lastSubmittedMsgRef.current = null; setReconnectExhausted(false); setIsSyncing(false); @@ -501,15 +534,8 @@ export function useCopilotStream({ return prev; }); - resumeStream(); - }, [ - sessionId, - hasActiveStream, - hydratedMessages, - status, - resumeStream, - setMessages, - ]); + resumeStreamRef.current(); + }, [sessionId, hasActiveStream, hydratedMessages, status, setMessages]); // Clear messages when session is null useEffect(() => { diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index 32e91bfd51..f93caabbb1 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -1606,6 +1606,35 @@ } }, "/api/chat/sessions/{session_id}/stream": { + "delete": { + "tags": ["v2", "chat", "chat"], + "summary": "Disconnect Session Stream", + "description": "Disconnect all active SSE listeners for a session.\n\nCalled by the frontend when the user switches away from a chat so the\nbackend releases XREAD listeners immediately rather than waiting for\nthe 5-10 s timeout.", + "operationId": "deleteV2DisconnectSessionStream", + "security": [{ "HTTPBearerJWT": [] }], + "parameters": [ + { + "name": "session_id", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Session Id" } + } + ], + "responses": { + "204": { "description": "Successful Response" }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + }, "get": { "tags": ["v2", "chat", "chat"], "summary": "Resume Session Stream", From 227c60abd31cdecfa141f7f96e89b620f0b02667 Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Wed, 15 Apr 2026 18:54:59 +0700 Subject: [PATCH 5/5] fix(backend/copilot): idempotency guard + frontend dedup fix for duplicate messages (#12788) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Why After merging #12782 to dev, a k8s rolling deployment triggered infrastructure-level POST retries — nginx detected the old pod's connection reset mid-stream and resent the same POST to a new pod. Both pods independently saved the user message and ran the executor, producing duplicate entries in the DB (seq 159, 161, 163) and a duplicate response in the chat. The model saw the same question 3× in its context window and spent its response commenting on that instead of answering. Two compounding issues: 1. **No backend idempotency**: `append_and_save_message` saves unconditionally — k8s/nginx retries silently produce duplicate turns. 2. **Frontend dedup cleared after success**: `lastSubmittedMsgRef.current = null` after every completed turn wipes the dedup guard, so any rapid re-submit of the same text (from a stalled UI or user double-click) slips through. ## What **Backend** — Redis idempotency gate in `stream_chat_post`: - Before saving the user message, compute `sha256(session_id + message)[:16]` and `SET NX ex=30` in Redis - If key already exists → duplicate: return empty SSE (`StreamFinish + [DONE]`) immediately, skip save + executor enqueue - User messages only (`is_user_message=True`); system/assistant messages bypass the check **Frontend** — Keep `lastSubmittedMsgRef` populated after success: - Remove `lastSubmittedMsgRef.current = null` on stream complete - `getSendSuppressionReason` already has a two-condition check: `ref === text AND lastUserMsg === text` — so legitimate re-asks (after a different question was answered) still work; only rapid re-sends of the exact same text while it's still the last user message are blocked ## How - 30 s Redis TTL covers infrastructure retry windows (k8s SIGTERM → connection reset → ingress retry typically < 5 s) - Empty SSE response is well-formed (StreamFinish + [DONE]) — frontend AI SDK marks the turn complete without rendering a ghost message - Frontend ref kept live means: submit "foo" → success → submit "foo" again instantly → suppressed. Submit "foo" → success → submit "bar" → proceeds (different text updates the ref). ## Tests - 3 new backend route tests: duplicate blocked, first POST proceeds, non-user messages bypass - 5 new frontend `getSendSuppressionReason` unit tests: fresh ref, reconnecting, duplicate suppressed, different-turn re-ask allowed, different text allowed ## Checklist - [x] I have read the [AutoGPT Contributing Guide](https://github.com/Significant-Gravitas/AutoGPT/blob/master/CONTRIBUTING.md) - [x] I have performed a self-review of my code - [x] I have added tests that prove the fix is effective - [x] I have run `poetry run format` and `pnpm format` + `pnpm lint` --- .../backend/api/features/chat/routes.py | 149 ++++++---- .../backend/api/features/chat/routes_test.py | 278 +++++++++++++++++- .../backend/backend/copilot/message_dedup.py | 71 +++++ .../backend/copilot/message_dedup_test.py | 94 ++++++ .../copilot/__tests__/helpers.test.ts | 71 ++++- .../(platform)/copilot/useCopilotStream.ts | 7 +- 6 files changed, 606 insertions(+), 64 deletions(-) create mode 100644 autogpt_platform/backend/backend/copilot/message_dedup.py create mode 100644 autogpt_platform/backend/backend/copilot/message_dedup_test.py diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index ac7325e201..496e958e17 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -18,6 +18,7 @@ from backend.copilot import stream_registry from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode from backend.copilot.db import get_chat_messages_paginated from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn +from backend.copilot.message_dedup import acquire_dedup_lock from backend.copilot.model import ( ChatMessage, ChatSession, @@ -840,6 +841,9 @@ async def stream_chat_post( # Also sanitise file_ids so only validated, workspace-scoped IDs are # forwarded downstream (e.g. to the executor via enqueue_copilot_turn). sanitized_file_ids: list[str] | None = None + # Capture the original message text BEFORE any mutation (attachment enrichment) + # so the idempotency hash is stable across retries. + original_message = request.message if request.file_ids and user_id: # Filter to valid UUIDs only to prevent DB abuse valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)] @@ -868,61 +872,91 @@ async def stream_chat_post( ) request.message += files_block + # ── Idempotency guard ──────────────────────────────────────────────────── + # Blocks duplicate executor tasks from concurrent/retried POSTs. + # See backend/copilot/message_dedup.py for the full lifecycle description. + dedup_lock = None + if request.is_user_message: + dedup_lock = await acquire_dedup_lock( + session_id, original_message, sanitized_file_ids + ) + if dedup_lock is None and (original_message or sanitized_file_ids): + + async def _empty_sse() -> AsyncGenerator[str, None]: + yield StreamFinish().to_sse() + yield "data: [DONE]\n\n" + + return StreamingResponse( + _empty_sse(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "X-Accel-Buffering": "no", + "Connection": "keep-alive", + "x-vercel-ai-ui-message-stream": "v1", + }, + ) + # Atomically append user message to session BEFORE creating task to avoid # race condition where GET_SESSION sees task as "running" but message isn't # saved yet. append_and_save_message re-fetches inside a lock to prevent # message loss from concurrent requests. - if request.message: - message = ChatMessage( - role="user" if request.is_user_message else "assistant", - content=request.message, - ) - if request.is_user_message: - track_user_message( - user_id=user_id, - session_id=session_id, - message_length=len(request.message), + # + # If any of these operations raises, release the dedup lock before propagating + # so subsequent retries are not blocked for 30 s. + try: + if request.message: + message = ChatMessage( + role="user" if request.is_user_message else "assistant", + content=request.message, ) - logger.info(f"[STREAM] Saving user message to session {session_id}") - await append_and_save_message(session_id, message) - logger.info(f"[STREAM] User message saved for session {session_id}") + if request.is_user_message: + track_user_message( + user_id=user_id, + session_id=session_id, + message_length=len(request.message), + ) + logger.info(f"[STREAM] Saving user message to session {session_id}") + await append_and_save_message(session_id, message) + logger.info(f"[STREAM] User message saved for session {session_id}") - # Create a task in the stream registry for reconnection support - turn_id = str(uuid4()) - log_meta["turn_id"] = turn_id + # Create a task in the stream registry for reconnection support + turn_id = str(uuid4()) + log_meta["turn_id"] = turn_id - session_create_start = time.perf_counter() - await stream_registry.create_session( - session_id=session_id, - user_id=user_id, - tool_call_id="chat_stream", - tool_name="chat", - turn_id=turn_id, - ) - logger.info( - f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms", - extra={ - "json_fields": { - **log_meta, - "duration_ms": (time.perf_counter() - session_create_start) * 1000, - } - }, - ) + session_create_start = time.perf_counter() + await stream_registry.create_session( + session_id=session_id, + user_id=user_id, + tool_call_id="chat_stream", + tool_name="chat", + turn_id=turn_id, + ) + logger.info( + f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms", + extra={ + "json_fields": { + **log_meta, + "duration_ms": (time.perf_counter() - session_create_start) * 1000, + } + }, + ) - # Per-turn stream is always fresh (unique turn_id), subscribe from beginning - subscribe_from_id = "0-0" - - await enqueue_copilot_turn( - session_id=session_id, - user_id=user_id, - message=request.message, - turn_id=turn_id, - is_user_message=request.is_user_message, - context=request.context, - file_ids=sanitized_file_ids, - mode=request.mode, - model=request.model, - ) + await enqueue_copilot_turn( + session_id=session_id, + user_id=user_id, + message=request.message, + turn_id=turn_id, + is_user_message=request.is_user_message, + context=request.context, + file_ids=sanitized_file_ids, + mode=request.mode, + model=request.model, + ) + except Exception: + if dedup_lock: + await dedup_lock.release() + raise setup_time = (time.perf_counter() - stream_start_time) * 1000 logger.info( @@ -930,6 +964,9 @@ async def stream_chat_post( extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}}, ) + # Per-turn stream is always fresh (unique turn_id), subscribe from beginning + subscribe_from_id = "0-0" + # SSE endpoint that subscribes to the task's stream async def event_generator() -> AsyncGenerator[str, None]: import time as time_module @@ -943,6 +980,12 @@ async def stream_chat_post( subscriber_queue = None first_chunk_yielded = False chunks_yielded = 0 + # True for every exit path except GeneratorExit (client disconnect). + # On disconnect the backend turn is still running — releasing the lock + # there would reopen the infra-retry duplicate window. The 30 s TTL + # is the fallback. All other exits (normal finish, early return, error) + # should release so the user can re-send the same message. + release_dedup_lock_on_exit = True try: # Subscribe from the position we captured before enqueuing # This avoids replaying old messages while catching all new ones @@ -954,8 +997,7 @@ async def stream_chat_post( if subscriber_queue is None: yield StreamFinish().to_sse() - yield "data: [DONE]\n\n" - return + return # finally releases dedup_lock # Read from the subscriber queue and yield to SSE logger.info( @@ -984,7 +1026,6 @@ async def stream_chat_post( yield chunk.to_sse() - # Check for finish signal if isinstance(chunk, StreamFinish): total_time = time_module.perf_counter() - event_gen_start logger.info( @@ -998,7 +1039,8 @@ async def stream_chat_post( } }, ) - break + break # finally releases dedup_lock + except asyncio.TimeoutError: yield StreamHeartbeat().to_sse() @@ -1013,7 +1055,7 @@ async def stream_chat_post( } }, ) - pass # Client disconnected - background task continues + release_dedup_lock_on_exit = False except Exception as e: elapsed = (time_module.perf_counter() - event_gen_start) * 1000 logger.error( @@ -1028,7 +1070,10 @@ async def stream_chat_post( code="stream_error", ).to_sse() yield StreamFinish().to_sse() + # finally releases dedup_lock finally: + if dedup_lock and release_dedup_lock_on_exit: + await dedup_lock.release() # Unsubscribe when client disconnects or stream ends if subscriber_queue is not None: try: diff --git a/autogpt_platform/backend/backend/api/features/chat/routes_test.py b/autogpt_platform/backend/backend/api/features/chat/routes_test.py index 74259b3463..597aad01ad 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes_test.py @@ -133,14 +133,30 @@ def test_stream_chat_rejects_too_many_file_ids(): assert response.status_code == 422 -def _mock_stream_internals(mocker: pytest_mock.MockFixture): +def _mock_stream_internals( + mocker: pytest_mock.MockerFixture, + *, + redis_set_returns: object = True, +): """Mock the async internals of stream_chat_post so tests can exercise - validation and enrichment logic without needing Redis/RabbitMQ.""" + validation and enrichment logic without needing Redis/RabbitMQ. + + Args: + redis_set_returns: Value returned by the mocked Redis ``set`` call. + ``True`` (default) simulates a fresh key (new message); + ``None`` simulates a collision (duplicate blocked). + + Returns: + A namespace with ``redis``, ``save``, and ``enqueue`` mock objects so + callers can make additional assertions about side-effects. + """ + import types + mocker.patch( "backend.api.features.chat.routes._validate_and_get_session", return_value=None, ) - mocker.patch( + mock_save = mocker.patch( "backend.api.features.chat.routes.append_and_save_message", return_value=None, ) @@ -150,7 +166,7 @@ def _mock_stream_internals(mocker: pytest_mock.MockFixture): "backend.api.features.chat.routes.stream_registry", mock_registry, ) - mocker.patch( + mock_enqueue = mocker.patch( "backend.api.features.chat.routes.enqueue_copilot_turn", return_value=None, ) @@ -158,9 +174,18 @@ def _mock_stream_internals(mocker: pytest_mock.MockFixture): "backend.api.features.chat.routes.track_user_message", return_value=None, ) + mock_redis = AsyncMock() + mock_redis.set = AsyncMock(return_value=redis_set_returns) + mocker.patch( + "backend.copilot.message_dedup.get_redis_async", + new_callable=AsyncMock, + return_value=mock_redis, + ) + ns = types.SimpleNamespace(redis=mock_redis, save=mock_save, enqueue=mock_enqueue) + return ns -def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture): +def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture): """Exactly 20 file_ids should be accepted (not rejected by validation).""" _mock_stream_internals(mocker) # Patch workspace lookup as imported by the routes module @@ -189,7 +214,7 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture): # ─── UUID format filtering ───────────────────────────────────────────── -def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture): +def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockerFixture): """Non-UUID strings in file_ids should be silently filtered out and NOT passed to the database query.""" _mock_stream_internals(mocker) @@ -228,7 +253,7 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture): # ─── Cross-workspace file_ids ───────────────────────────────────────── -def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture): +def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockerFixture): """The batch query should scope to the user's workspace.""" _mock_stream_internals(mocker) mocker.patch( @@ -257,7 +282,7 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture): # ─── Rate limit → 429 ───────────────────────────────────────────────── -def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFixture): +def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockerFixture): """When check_rate_limit raises RateLimitExceeded for daily limit the endpoint returns 429.""" from backend.copilot.rate_limit import RateLimitExceeded @@ -278,7 +303,9 @@ def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFix assert "daily" in response.json()["detail"].lower() -def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFixture): +def test_stream_chat_returns_429_on_weekly_rate_limit( + mocker: pytest_mock.MockerFixture, +): """When check_rate_limit raises RateLimitExceeded for weekly limit the endpoint returns 429.""" from backend.copilot.rate_limit import RateLimitExceeded @@ -301,7 +328,7 @@ def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFi assert "resets in" in detail -def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockFixture): +def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockerFixture): """The 429 response detail should include the human-readable reset time.""" from backend.copilot.rate_limit import RateLimitExceeded @@ -679,6 +706,237 @@ class TestStripInjectedContext: assert result["content"] == "hello" +# ─── Idempotency / duplicate-POST guard ────────────────────────────── + + +def test_stream_chat_blocks_duplicate_post_returns_empty_sse( + mocker: pytest_mock.MockerFixture, +) -> None: + """A second POST with the same message within the 30-s window must return + an empty SSE stream (StreamFinish + [DONE]) so the frontend marks the + turn complete without creating a ghost response.""" + # redis_set_returns=None simulates a collision: the NX key already exists. + ns = _mock_stream_internals(mocker, redis_set_returns=None) + + response = client.post( + "/sessions/sess-dup/stream", + json={"message": "duplicate message", "is_user_message": True}, + ) + + assert response.status_code == 200 + body = response.text + # The response must contain StreamFinish (type=finish) and the SSE [DONE] terminator. + assert '"finish"' in body + assert "[DONE]" in body + # The empty SSE response must include the AI SDK protocol header so the + # frontend treats it as a valid stream and marks the turn complete. + assert response.headers.get("x-vercel-ai-ui-message-stream") == "v1" + # The duplicate guard must prevent save/enqueue side effects. + ns.save.assert_not_called() + ns.enqueue.assert_not_called() + + +def test_stream_chat_first_post_proceeds_normally( + mocker: pytest_mock.MockerFixture, +) -> None: + """The first POST (Redis NX key set successfully) must proceed through the + normal streaming path — no early return.""" + ns = _mock_stream_internals(mocker, redis_set_returns=True) + + response = client.post( + "/sessions/sess-new/stream", + json={"message": "first message", "is_user_message": True}, + ) + + assert response.status_code == 200 + # Redis set must have been called once with the NX flag. + ns.redis.set.assert_called_once() + call_kwargs = ns.redis.set.call_args + assert call_kwargs.kwargs.get("nx") is True + + +def test_stream_chat_dedup_skipped_for_non_user_messages( + mocker: pytest_mock.MockerFixture, +) -> None: + """System/assistant messages (is_user_message=False) bypass the dedup + guard — they are injected programmatically and must always be processed.""" + ns = _mock_stream_internals(mocker, redis_set_returns=None) + + response = client.post( + "/sessions/sess-sys/stream", + json={"message": "system context", "is_user_message": False}, + ) + + # Even though redis_set_returns=None (would block a user message), + # the endpoint must proceed because is_user_message=False. + assert response.status_code == 200 + ns.redis.set.assert_not_called() + + +def test_stream_chat_dedup_hash_uses_original_message_not_mutated( + mocker: pytest_mock.MockerFixture, +) -> None: + """The dedup hash must be computed from the original request message, + not the mutated version that has the [Attached files] block appended. + A file_id is sent so the route actually appends the [Attached files] block, + exercising the mutation path — the hash must still match the original text.""" + import hashlib + + ns = _mock_stream_internals(mocker, redis_set_returns=True) + + file_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + # Mock workspace + prisma so the attachment block is actually appended. + mocker.patch( + "backend.api.features.chat.routes.get_or_create_workspace", + return_value=type("W", (), {"id": "ws-1"})(), + ) + fake_file = type( + "F", + (), + { + "id": file_id, + "name": "doc.pdf", + "mimeType": "application/pdf", + "sizeBytes": 1024, + }, + )() + mock_prisma = mocker.MagicMock() + mock_prisma.find_many = mocker.AsyncMock(return_value=[fake_file]) + mocker.patch( + "prisma.models.UserWorkspaceFile.prisma", + return_value=mock_prisma, + ) + + response = client.post( + "/sessions/sess-hash/stream", + json={ + "message": "plain message", + "is_user_message": True, + "file_ids": [file_id], + }, + ) + + assert response.status_code == 200 + ns.redis.set.assert_called_once() + call_args = ns.redis.set.call_args + dedup_key = call_args.args[0] + + # Hash must use the original message + sorted file IDs, not the mutated text. + expected_hash = hashlib.sha256( + f"sess-hash:plain message:{file_id}".encode() + ).hexdigest()[:16] + expected_key = f"chat:msg_dedup:sess-hash:{expected_hash}" + assert dedup_key == expected_key, ( + f"Dedup key {dedup_key!r} does not match expected {expected_key!r} — " + "hash may be using mutated message or wrong inputs" + ) + + +def test_stream_chat_dedup_key_released_after_stream_finish( + mocker: pytest_mock.MockerFixture, +) -> None: + """The dedup Redis key must be deleted after the turn completes (when + subscriber_queue is None the route yields StreamFinish immediately and + should release the key so the user can re-send the same message).""" + from unittest.mock import AsyncMock as _AsyncMock + + # Set up all internals manually so we can control subscribe_to_session. + mocker.patch( + "backend.api.features.chat.routes._validate_and_get_session", + return_value=None, + ) + mocker.patch( + "backend.api.features.chat.routes.append_and_save_message", + return_value=None, + ) + mocker.patch( + "backend.api.features.chat.routes.enqueue_copilot_turn", + return_value=None, + ) + mocker.patch( + "backend.api.features.chat.routes.track_user_message", + return_value=None, + ) + mock_registry = mocker.MagicMock() + mock_registry.create_session = _AsyncMock(return_value=None) + # None → early-finish path: StreamFinish yielded immediately, dedup key released. + mock_registry.subscribe_to_session = _AsyncMock(return_value=None) + mocker.patch( + "backend.api.features.chat.routes.stream_registry", + mock_registry, + ) + mock_redis = mocker.AsyncMock() + mock_redis.set = _AsyncMock(return_value=True) + mocker.patch( + "backend.copilot.message_dedup.get_redis_async", + new_callable=_AsyncMock, + return_value=mock_redis, + ) + + response = client.post( + "/sessions/sess-finish/stream", + json={"message": "hello", "is_user_message": True}, + ) + + assert response.status_code == 200 + body = response.text + assert '"finish"' in body + # The dedup key must be released so intentional re-sends are allowed. + mock_redis.delete.assert_called_once() + + +def test_stream_chat_dedup_key_released_even_when_redis_delete_raises( + mocker: pytest_mock.MockerFixture, +) -> None: + """The route must not crash when the dedup Redis delete fails on the + subscriber_queue-is-None early-finish path (except Exception: pass).""" + from unittest.mock import AsyncMock as _AsyncMock + + mocker.patch( + "backend.api.features.chat.routes._validate_and_get_session", + return_value=None, + ) + mocker.patch( + "backend.api.features.chat.routes.append_and_save_message", + return_value=None, + ) + mocker.patch( + "backend.api.features.chat.routes.enqueue_copilot_turn", + return_value=None, + ) + mocker.patch( + "backend.api.features.chat.routes.track_user_message", + return_value=None, + ) + mock_registry = mocker.MagicMock() + mock_registry.create_session = _AsyncMock(return_value=None) + mock_registry.subscribe_to_session = _AsyncMock(return_value=None) + mocker.patch( + "backend.api.features.chat.routes.stream_registry", + mock_registry, + ) + mock_redis = mocker.AsyncMock() + mock_redis.set = _AsyncMock(return_value=True) + # Make the delete raise so the except-pass branch is exercised. + mock_redis.delete = _AsyncMock(side_effect=RuntimeError("redis gone")) + mocker.patch( + "backend.copilot.message_dedup.get_redis_async", + new_callable=_AsyncMock, + return_value=mock_redis, + ) + + # Should not raise even though delete fails. + response = client.post( + "/sessions/sess-finish-err/stream", + json={"message": "hello", "is_user_message": True}, + ) + + assert response.status_code == 200 + assert '"finish"' in response.text + # delete must have been attempted — the except-pass branch silenced the error. + mock_redis.delete.assert_called_once() + + # ─── DELETE /sessions/{id}/stream — disconnect listeners ────────────── diff --git a/autogpt_platform/backend/backend/copilot/message_dedup.py b/autogpt_platform/backend/backend/copilot/message_dedup.py new file mode 100644 index 0000000000..2af13b559a --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/message_dedup.py @@ -0,0 +1,71 @@ +"""Per-request idempotency lock for the /stream endpoint. + +Prevents duplicate executor tasks from concurrent or retried POSTs (e.g. k8s +rolling-deploy retries, nginx upstream retries, rapid double-clicks). + +Lifecycle +--------- +1. ``acquire()`` — computes a stable hash of (session_id, message, file_ids) + and atomically sets a Redis NX key. Returns a ``_DedupLock`` on success or + ``None`` when the key already exists (duplicate request). +2. ``release()`` — deletes the key. Must be called on turn completion or turn + error so the next legitimate send is never blocked. +3. On client disconnect (``GeneratorExit``) the lock must NOT be released — + the backend turn is still running, and releasing would reopen the duplicate + window for infra-level retries. The 30 s TTL is the safety net. +""" + +import hashlib +import logging + +from backend.data.redis_client import get_redis_async + +logger = logging.getLogger(__name__) + +_KEY_PREFIX = "chat:msg_dedup" +_TTL_SECONDS = 30 + + +class _DedupLock: + def __init__(self, key: str, redis) -> None: + self._key = key + self._redis = redis + + async def release(self) -> None: + """Best-effort key deletion. The TTL handles failures silently.""" + try: + await self._redis.delete(self._key) + except Exception: + pass + + +async def acquire_dedup_lock( + session_id: str, + message: str | None, + file_ids: list[str] | None, +) -> _DedupLock | None: + """Acquire the idempotency lock for this (session, message, files) tuple. + + Returns a ``_DedupLock`` when the lock is freshly acquired (first request). + Returns ``None`` when a duplicate is detected (lock already held). + Returns ``None`` when there is nothing to deduplicate (no message, no files). + """ + if not message and not file_ids: + return None + + sorted_ids = ":".join(sorted(file_ids or [])) + content_hash = hashlib.sha256( + f"{session_id}:{message or ''}:{sorted_ids}".encode() + ).hexdigest()[:16] + key = f"{_KEY_PREFIX}:{session_id}:{content_hash}" + + redis = await get_redis_async() + acquired = await redis.set(key, "1", ex=_TTL_SECONDS, nx=True) + if not acquired: + logger.warning( + f"[STREAM] Duplicate user message blocked for session {session_id}, " + f"hash={content_hash} — returning empty SSE", + ) + return None + + return _DedupLock(key, redis) diff --git a/autogpt_platform/backend/backend/copilot/message_dedup_test.py b/autogpt_platform/backend/backend/copilot/message_dedup_test.py new file mode 100644 index 0000000000..935ddd36b6 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/message_dedup_test.py @@ -0,0 +1,94 @@ +"""Unit tests for backend.copilot.message_dedup.""" + +from unittest.mock import AsyncMock + +import pytest +import pytest_mock + +from backend.copilot.message_dedup import _KEY_PREFIX, acquire_dedup_lock + + +def _patch_redis(mocker: pytest_mock.MockerFixture, *, set_returns): + mock_redis = AsyncMock() + mock_redis.set = AsyncMock(return_value=set_returns) + mocker.patch( + "backend.copilot.message_dedup.get_redis_async", + new_callable=AsyncMock, + return_value=mock_redis, + ) + return mock_redis + + +@pytest.mark.asyncio +async def test_acquire_returns_none_when_no_message_no_files( + mocker: pytest_mock.MockerFixture, +) -> None: + """Nothing to deduplicate — no Redis call made, None returned.""" + mock_redis = _patch_redis(mocker, set_returns=True) + result = await acquire_dedup_lock("sess-1", None, None) + assert result is None + mock_redis.set.assert_not_called() + + +@pytest.mark.asyncio +async def test_acquire_returns_lock_on_first_request( + mocker: pytest_mock.MockerFixture, +) -> None: + """First request acquires the lock and returns a _DedupLock.""" + mock_redis = _patch_redis(mocker, set_returns=True) + lock = await acquire_dedup_lock("sess-1", "hello", None) + assert lock is not None + mock_redis.set.assert_called_once() + key_arg = mock_redis.set.call_args.args[0] + assert key_arg.startswith(f"{_KEY_PREFIX}:sess-1:") + + +@pytest.mark.asyncio +async def test_acquire_returns_none_on_duplicate( + mocker: pytest_mock.MockerFixture, +) -> None: + """Duplicate request (NX fails) returns None to signal the caller.""" + _patch_redis(mocker, set_returns=None) + result = await acquire_dedup_lock("sess-1", "hello", None) + assert result is None + + +@pytest.mark.asyncio +async def test_acquire_key_stable_across_file_order( + mocker: pytest_mock.MockerFixture, +) -> None: + """File IDs are sorted before hashing so order doesn't affect the key.""" + mock_redis_1 = _patch_redis(mocker, set_returns=True) + await acquire_dedup_lock("sess-1", "msg", ["b", "a"]) + key_ab = mock_redis_1.set.call_args.args[0] + + mock_redis_2 = _patch_redis(mocker, set_returns=True) + await acquire_dedup_lock("sess-1", "msg", ["a", "b"]) + key_ba = mock_redis_2.set.call_args.args[0] + + assert key_ab == key_ba + + +@pytest.mark.asyncio +async def test_release_deletes_key( + mocker: pytest_mock.MockerFixture, +) -> None: + """release() calls Redis delete exactly once.""" + mock_redis = _patch_redis(mocker, set_returns=True) + lock = await acquire_dedup_lock("sess-1", "hello", None) + assert lock is not None + await lock.release() + mock_redis.delete.assert_called_once() + + +@pytest.mark.asyncio +async def test_release_swallows_redis_error( + mocker: pytest_mock.MockerFixture, +) -> None: + """release() must not raise even when Redis delete fails.""" + mock_redis = _patch_redis(mocker, set_returns=True) + mock_redis.delete = AsyncMock(side_effect=RuntimeError("redis down")) + lock = await acquire_dedup_lock("sess-1", "hello", None) + assert lock is not None + await lock.release() # must not raise + mock_redis.delete.assert_called_once() diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/helpers.test.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/helpers.test.ts index 712aaaf508..9580ef349a 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/helpers.test.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/helpers.test.ts @@ -1,6 +1,7 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; import { IMPERSONATION_HEADER_NAME } from "@/lib/constants"; -import { getCopilotAuthHeaders } from "../helpers"; +import { getCopilotAuthHeaders, getSendSuppressionReason } from "../helpers"; +import type { UIMessage } from "ai"; vi.mock("@/lib/supabase/actions", () => ({ getWebSocketToken: vi.fn(), @@ -72,3 +73,71 @@ describe("getCopilotAuthHeaders", () => { ); }); }); + +// ─── getSendSuppressionReason ───────────────────────────────────────────────── + +function makeUserMsg(text: string): UIMessage { + return { + id: "msg-1", + role: "user", + content: text, + parts: [{ type: "text", text }], + } as UIMessage; +} + +describe("getSendSuppressionReason", () => { + it("returns null when no dedup context exists (fresh ref)", () => { + const result = getSendSuppressionReason({ + text: "hello", + isReconnectScheduled: false, + lastSubmittedText: null, + messages: [], + }); + expect(result).toBeNull(); + }); + + it("returns 'reconnecting' when reconnect is scheduled regardless of text", () => { + const result = getSendSuppressionReason({ + text: "hello", + isReconnectScheduled: true, + lastSubmittedText: null, + messages: [], + }); + expect(result).toBe("reconnecting"); + }); + + it("returns 'duplicate' when same text was submitted and is the last user message", () => { + // This is the core regression test: after a successful turn the ref + // is intentionally NOT cleared to null, so submitting the same text + // again is caught here. + const result = getSendSuppressionReason({ + text: "hello", + isReconnectScheduled: false, + lastSubmittedText: "hello", + messages: [makeUserMsg("hello")], + }); + expect(result).toBe("duplicate"); + }); + + it("returns null when same ref text but different last user message (different question)", () => { + // User asked "hello" before, got a reply, then asked a different question + // — the last user message in chat is now different, so no suppression. + const result = getSendSuppressionReason({ + text: "hello", + isReconnectScheduled: false, + lastSubmittedText: "hello", + messages: [makeUserMsg("hello"), makeUserMsg("something else")], + }); + expect(result).toBeNull(); + }); + + it("returns null when text differs from lastSubmittedText", () => { + const result = getSendSuppressionReason({ + text: "new question", + isReconnectScheduled: false, + lastSubmittedText: "old question", + messages: [makeUserMsg("old question")], + }); + expect(result).toBeNull(); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts index 85709f23d9..666b87bfba 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts @@ -497,7 +497,12 @@ export function useCopilotStream({ if (status === "ready") { reconnectAttemptsRef.current = 0; hasShownDisconnectToast.current = false; - lastSubmittedMsgRef.current = null; + // Intentionally NOT clearing lastSubmittedMsgRef here: keeping the last + // submitted text prevents getSendSuppressionReason from allowing a + // duplicate POST of the same message immediately after a successful turn + // (the "duplicate" branch checks both the ref and the visible last user + // message, so legitimate re-sends after a different reply are still + // allowed). setReconnectExhausted(false); } }