diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index c09ef9fe76..7d0521cb81 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -15,9 +15,10 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator from backend.copilot import service as chat_service from backend.copilot import stream_registry -from backend.copilot.config import ChatConfig, CopilotMode +from backend.copilot.config import ChatConfig, CopilotLlmModel, CopilotMode from backend.copilot.db import get_chat_messages_paginated from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn +from backend.copilot.message_dedup import acquire_dedup_lock from backend.copilot.model import ( ChatMessage, ChatSession, @@ -140,6 +141,11 @@ class StreamChatRequest(BaseModel): description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. " "If None, uses the server default (extended_thinking).", ) + model: CopilotLlmModel | None = Field( + default=None, + description="Model tier: 'standard' for the default model, 'advanced' for the highest-capability model. " + "If None, the server applies per-user LD targeting then falls back to config.", + ) class CreateSessionRequest(BaseModel): @@ -377,6 +383,31 @@ async def delete_session( return Response(status_code=204) +@router.delete( + "/sessions/{session_id}/stream", + dependencies=[Security(auth.requires_user)], + status_code=204, +) +async def disconnect_session_stream( + session_id: str, + user_id: Annotated[str, Security(auth.get_user_id)], +) -> Response: + """Disconnect all active SSE listeners for a session. + + Called by the frontend when the user switches away from a chat so the + backend releases XREAD listeners immediately rather than waiting for + the 5-10 s timeout. + """ + session = await get_chat_session(session_id, user_id) + if not session: + raise HTTPException( + status_code=404, + detail=f"Session {session_id} not found or access denied", + ) + await stream_registry.disconnect_all_listeners(session_id) + return Response(status_code=204) + + @router.patch( "/sessions/{session_id}/title", summary="Update session title", @@ -811,6 +842,9 @@ async def stream_chat_post( # Also sanitise file_ids so only validated, workspace-scoped IDs are # forwarded downstream (e.g. to the executor via enqueue_copilot_turn). sanitized_file_ids: list[str] | None = None + # Capture the original message text BEFORE any mutation (attachment enrichment) + # so the idempotency hash is stable across retries. + original_message = request.message if request.file_ids and user_id: # Filter to valid UUIDs only to prevent DB abuse valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)] @@ -839,60 +873,91 @@ async def stream_chat_post( ) request.message += files_block + # ── Idempotency guard ──────────────────────────────────────────────────── + # Blocks duplicate executor tasks from concurrent/retried POSTs. + # See backend/copilot/message_dedup.py for the full lifecycle description. 
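+    # acquire_dedup_lock returns None both for a detected duplicate and when
+    # there is nothing to deduplicate (no message, no files); the
+    # `(original_message or sanitized_file_ids)` check below separates the two.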
+ dedup_lock = None + if request.is_user_message: + dedup_lock = await acquire_dedup_lock( + session_id, original_message, sanitized_file_ids + ) + if dedup_lock is None and (original_message or sanitized_file_ids): + + async def _empty_sse() -> AsyncGenerator[str, None]: + yield StreamFinish().to_sse() + yield "data: [DONE]\n\n" + + return StreamingResponse( + _empty_sse(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "X-Accel-Buffering": "no", + "Connection": "keep-alive", + "x-vercel-ai-ui-message-stream": "v1", + }, + ) + # Atomically append user message to session BEFORE creating task to avoid # race condition where GET_SESSION sees task as "running" but message isn't # saved yet. append_and_save_message re-fetches inside a lock to prevent # message loss from concurrent requests. - if request.message: - message = ChatMessage( - role="user" if request.is_user_message else "assistant", - content=request.message, - ) - if request.is_user_message: - track_user_message( - user_id=user_id, - session_id=session_id, - message_length=len(request.message), + # + # If any of these operations raises, release the dedup lock before propagating + # so subsequent retries are not blocked for 30 s. + try: + if request.message: + message = ChatMessage( + role="user" if request.is_user_message else "assistant", + content=request.message, ) - logger.info(f"[STREAM] Saving user message to session {session_id}") - await append_and_save_message(session_id, message) - logger.info(f"[STREAM] User message saved for session {session_id}") + if request.is_user_message: + track_user_message( + user_id=user_id, + session_id=session_id, + message_length=len(request.message), + ) + logger.info(f"[STREAM] Saving user message to session {session_id}") + await append_and_save_message(session_id, message) + logger.info(f"[STREAM] User message saved for session {session_id}") - # Create a task in the stream registry for reconnection support - turn_id = str(uuid4()) - log_meta["turn_id"] = turn_id + # Create a task in the stream registry for reconnection support + turn_id = str(uuid4()) + log_meta["turn_id"] = turn_id - session_create_start = time.perf_counter() - await stream_registry.create_session( - session_id=session_id, - user_id=user_id, - tool_call_id="chat_stream", - tool_name="chat", - turn_id=turn_id, - ) - logger.info( - f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms", - extra={ - "json_fields": { - **log_meta, - "duration_ms": (time.perf_counter() - session_create_start) * 1000, - } - }, - ) + session_create_start = time.perf_counter() + await stream_registry.create_session( + session_id=session_id, + user_id=user_id, + tool_call_id="chat_stream", + tool_name="chat", + turn_id=turn_id, + ) + logger.info( + f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms", + extra={ + "json_fields": { + **log_meta, + "duration_ms": (time.perf_counter() - session_create_start) * 1000, + } + }, + ) - # Per-turn stream is always fresh (unique turn_id), subscribe from beginning - subscribe_from_id = "0-0" - - await enqueue_copilot_turn( - session_id=session_id, - user_id=user_id, - message=request.message, - turn_id=turn_id, - is_user_message=request.is_user_message, - context=request.context, - file_ids=sanitized_file_ids, - mode=request.mode, - ) + await enqueue_copilot_turn( + session_id=session_id, + user_id=user_id, + message=request.message, + turn_id=turn_id, + 
is_user_message=request.is_user_message, + context=request.context, + file_ids=sanitized_file_ids, + mode=request.mode, + model=request.model, + ) + except Exception: + if dedup_lock: + await dedup_lock.release() + raise setup_time = (time.perf_counter() - stream_start_time) * 1000 logger.info( @@ -900,6 +965,9 @@ async def stream_chat_post( extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}}, ) + # Per-turn stream is always fresh (unique turn_id), subscribe from beginning + subscribe_from_id = "0-0" + # SSE endpoint that subscribes to the task's stream async def event_generator() -> AsyncGenerator[str, None]: import time as time_module @@ -913,6 +981,12 @@ async def stream_chat_post( subscriber_queue = None first_chunk_yielded = False chunks_yielded = 0 + # True for every exit path except GeneratorExit (client disconnect). + # On disconnect the backend turn is still running — releasing the lock + # there would reopen the infra-retry duplicate window. The 30 s TTL + # is the fallback. All other exits (normal finish, early return, error) + # should release so the user can re-send the same message. + release_dedup_lock_on_exit = True try: # Subscribe from the position we captured before enqueuing # This avoids replaying old messages while catching all new ones @@ -924,8 +998,7 @@ async def stream_chat_post( if subscriber_queue is None: yield StreamFinish().to_sse() - yield "data: [DONE]\n\n" - return + return # finally releases dedup_lock # Read from the subscriber queue and yield to SSE logger.info( @@ -954,7 +1027,6 @@ async def stream_chat_post( yield chunk.to_sse() - # Check for finish signal if isinstance(chunk, StreamFinish): total_time = time_module.perf_counter() - event_gen_start logger.info( @@ -968,7 +1040,8 @@ async def stream_chat_post( } }, ) - break + break # finally releases dedup_lock + except asyncio.TimeoutError: yield StreamHeartbeat().to_sse() @@ -983,7 +1056,7 @@ async def stream_chat_post( } }, ) - pass # Client disconnected - background task continues + release_dedup_lock_on_exit = False except Exception as e: elapsed = (time_module.perf_counter() - event_gen_start) * 1000 logger.error( @@ -998,7 +1071,10 @@ async def stream_chat_post( code="stream_error", ).to_sse() yield StreamFinish().to_sse() + # finally releases dedup_lock finally: + if dedup_lock and release_dedup_lock_on_exit: + await dedup_lock.release() # Unsubscribe when client disconnects or stream ends if subscriber_queue is not None: try: diff --git a/autogpt_platform/backend/backend/api/features/chat/routes_test.py b/autogpt_platform/backend/backend/api/features/chat/routes_test.py index f3896c7098..597aad01ad 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes_test.py @@ -133,14 +133,30 @@ def test_stream_chat_rejects_too_many_file_ids(): assert response.status_code == 422 -def _mock_stream_internals(mocker: pytest_mock.MockFixture): +def _mock_stream_internals( + mocker: pytest_mock.MockerFixture, + *, + redis_set_returns: object = True, +): """Mock the async internals of stream_chat_post so tests can exercise - validation and enrichment logic without needing Redis/RabbitMQ.""" + validation and enrichment logic without needing Redis/RabbitMQ. + + Args: + redis_set_returns: Value returned by the mocked Redis ``set`` call. + ``True`` (default) simulates a fresh key (new message); + ``None`` simulates a collision (duplicate blocked). 
+ + Returns: + A namespace with ``redis``, ``save``, and ``enqueue`` mock objects so + callers can make additional assertions about side-effects. + """ + import types + mocker.patch( "backend.api.features.chat.routes._validate_and_get_session", return_value=None, ) - mocker.patch( + mock_save = mocker.patch( "backend.api.features.chat.routes.append_and_save_message", return_value=None, ) @@ -150,7 +166,7 @@ def _mock_stream_internals(mocker: pytest_mock.MockFixture): "backend.api.features.chat.routes.stream_registry", mock_registry, ) - mocker.patch( + mock_enqueue = mocker.patch( "backend.api.features.chat.routes.enqueue_copilot_turn", return_value=None, ) @@ -158,9 +174,18 @@ def _mock_stream_internals(mocker: pytest_mock.MockFixture): "backend.api.features.chat.routes.track_user_message", return_value=None, ) + mock_redis = AsyncMock() + mock_redis.set = AsyncMock(return_value=redis_set_returns) + mocker.patch( + "backend.copilot.message_dedup.get_redis_async", + new_callable=AsyncMock, + return_value=mock_redis, + ) + ns = types.SimpleNamespace(redis=mock_redis, save=mock_save, enqueue=mock_enqueue) + return ns -def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture): +def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockerFixture): """Exactly 20 file_ids should be accepted (not rejected by validation).""" _mock_stream_internals(mocker) # Patch workspace lookup as imported by the routes module @@ -189,7 +214,7 @@ def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture): # ─── UUID format filtering ───────────────────────────────────────────── -def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture): +def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockerFixture): """Non-UUID strings in file_ids should be silently filtered out and NOT passed to the database query.""" _mock_stream_internals(mocker) @@ -228,7 +253,7 @@ def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture): # ─── Cross-workspace file_ids ───────────────────────────────────────── -def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture): +def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockerFixture): """The batch query should scope to the user's workspace.""" _mock_stream_internals(mocker) mocker.patch( @@ -257,7 +282,7 @@ def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture): # ─── Rate limit → 429 ───────────────────────────────────────────────── -def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFixture): +def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockerFixture): """When check_rate_limit raises RateLimitExceeded for daily limit the endpoint returns 429.""" from backend.copilot.rate_limit import RateLimitExceeded @@ -278,7 +303,9 @@ def test_stream_chat_returns_429_on_daily_rate_limit(mocker: pytest_mock.MockFix assert "daily" in response.json()["detail"].lower() -def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFixture): +def test_stream_chat_returns_429_on_weekly_rate_limit( + mocker: pytest_mock.MockerFixture, +): """When check_rate_limit raises RateLimitExceeded for weekly limit the endpoint returns 429.""" from backend.copilot.rate_limit import RateLimitExceeded @@ -301,7 +328,7 @@ def test_stream_chat_returns_429_on_weekly_rate_limit(mocker: pytest_mock.MockFi assert "resets in" in detail -def test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockFixture): +def 
test_stream_chat_429_includes_reset_time(mocker: pytest_mock.MockerFixture): """The 429 response detail should include the human-readable reset time.""" from backend.copilot.rate_limit import RateLimitExceeded @@ -677,3 +704,279 @@ class TestStripInjectedContext: result = _strip_injected_context(msg) # Without a role, the helper short-circuits without touching content. assert result["content"] == "hello" + + +# ─── Idempotency / duplicate-POST guard ────────────────────────────── + + +def test_stream_chat_blocks_duplicate_post_returns_empty_sse( + mocker: pytest_mock.MockerFixture, +) -> None: + """A second POST with the same message within the 30-s window must return + an empty SSE stream (StreamFinish + [DONE]) so the frontend marks the + turn complete without creating a ghost response.""" + # redis_set_returns=None simulates a collision: the NX key already exists. + ns = _mock_stream_internals(mocker, redis_set_returns=None) + + response = client.post( + "/sessions/sess-dup/stream", + json={"message": "duplicate message", "is_user_message": True}, + ) + + assert response.status_code == 200 + body = response.text + # The response must contain StreamFinish (type=finish) and the SSE [DONE] terminator. + assert '"finish"' in body + assert "[DONE]" in body + # The empty SSE response must include the AI SDK protocol header so the + # frontend treats it as a valid stream and marks the turn complete. + assert response.headers.get("x-vercel-ai-ui-message-stream") == "v1" + # The duplicate guard must prevent save/enqueue side effects. + ns.save.assert_not_called() + ns.enqueue.assert_not_called() + + +def test_stream_chat_first_post_proceeds_normally( + mocker: pytest_mock.MockerFixture, +) -> None: + """The first POST (Redis NX key set successfully) must proceed through the + normal streaming path — no early return.""" + ns = _mock_stream_internals(mocker, redis_set_returns=True) + + response = client.post( + "/sessions/sess-new/stream", + json={"message": "first message", "is_user_message": True}, + ) + + assert response.status_code == 200 + # Redis set must have been called once with the NX flag. + ns.redis.set.assert_called_once() + call_kwargs = ns.redis.set.call_args + assert call_kwargs.kwargs.get("nx") is True + + +def test_stream_chat_dedup_skipped_for_non_user_messages( + mocker: pytest_mock.MockerFixture, +) -> None: + """System/assistant messages (is_user_message=False) bypass the dedup + guard — they are injected programmatically and must always be processed.""" + ns = _mock_stream_internals(mocker, redis_set_returns=None) + + response = client.post( + "/sessions/sess-sys/stream", + json={"message": "system context", "is_user_message": False}, + ) + + # Even though redis_set_returns=None (would block a user message), + # the endpoint must proceed because is_user_message=False. + assert response.status_code == 200 + ns.redis.set.assert_not_called() + + +def test_stream_chat_dedup_hash_uses_original_message_not_mutated( + mocker: pytest_mock.MockerFixture, +) -> None: + """The dedup hash must be computed from the original request message, + not the mutated version that has the [Attached files] block appended. + A file_id is sent so the route actually appends the [Attached files] block, + exercising the mutation path — the hash must still match the original text.""" + import hashlib + + ns = _mock_stream_internals(mocker, redis_set_returns=True) + + file_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + # Mock workspace + prisma so the attachment block is actually appended. 
+ mocker.patch( + "backend.api.features.chat.routes.get_or_create_workspace", + return_value=type("W", (), {"id": "ws-1"})(), + ) + fake_file = type( + "F", + (), + { + "id": file_id, + "name": "doc.pdf", + "mimeType": "application/pdf", + "sizeBytes": 1024, + }, + )() + mock_prisma = mocker.MagicMock() + mock_prisma.find_many = mocker.AsyncMock(return_value=[fake_file]) + mocker.patch( + "prisma.models.UserWorkspaceFile.prisma", + return_value=mock_prisma, + ) + + response = client.post( + "/sessions/sess-hash/stream", + json={ + "message": "plain message", + "is_user_message": True, + "file_ids": [file_id], + }, + ) + + assert response.status_code == 200 + ns.redis.set.assert_called_once() + call_args = ns.redis.set.call_args + dedup_key = call_args.args[0] + + # Hash must use the original message + sorted file IDs, not the mutated text. + expected_hash = hashlib.sha256( + f"sess-hash:plain message:{file_id}".encode() + ).hexdigest()[:16] + expected_key = f"chat:msg_dedup:sess-hash:{expected_hash}" + assert dedup_key == expected_key, ( + f"Dedup key {dedup_key!r} does not match expected {expected_key!r} — " + "hash may be using mutated message or wrong inputs" + ) + + +def test_stream_chat_dedup_key_released_after_stream_finish( + mocker: pytest_mock.MockerFixture, +) -> None: + """The dedup Redis key must be deleted after the turn completes (when + subscriber_queue is None the route yields StreamFinish immediately and + should release the key so the user can re-send the same message).""" + from unittest.mock import AsyncMock as _AsyncMock + + # Set up all internals manually so we can control subscribe_to_session. + mocker.patch( + "backend.api.features.chat.routes._validate_and_get_session", + return_value=None, + ) + mocker.patch( + "backend.api.features.chat.routes.append_and_save_message", + return_value=None, + ) + mocker.patch( + "backend.api.features.chat.routes.enqueue_copilot_turn", + return_value=None, + ) + mocker.patch( + "backend.api.features.chat.routes.track_user_message", + return_value=None, + ) + mock_registry = mocker.MagicMock() + mock_registry.create_session = _AsyncMock(return_value=None) + # None → early-finish path: StreamFinish yielded immediately, dedup key released. + mock_registry.subscribe_to_session = _AsyncMock(return_value=None) + mocker.patch( + "backend.api.features.chat.routes.stream_registry", + mock_registry, + ) + mock_redis = mocker.AsyncMock() + mock_redis.set = _AsyncMock(return_value=True) + mocker.patch( + "backend.copilot.message_dedup.get_redis_async", + new_callable=_AsyncMock, + return_value=mock_redis, + ) + + response = client.post( + "/sessions/sess-finish/stream", + json={"message": "hello", "is_user_message": True}, + ) + + assert response.status_code == 200 + body = response.text + assert '"finish"' in body + # The dedup key must be released so intentional re-sends are allowed. 
+ mock_redis.delete.assert_called_once() + + +def test_stream_chat_dedup_key_released_even_when_redis_delete_raises( + mocker: pytest_mock.MockerFixture, +) -> None: + """The route must not crash when the dedup Redis delete fails on the + subscriber_queue-is-None early-finish path (except Exception: pass).""" + from unittest.mock import AsyncMock as _AsyncMock + + mocker.patch( + "backend.api.features.chat.routes._validate_and_get_session", + return_value=None, + ) + mocker.patch( + "backend.api.features.chat.routes.append_and_save_message", + return_value=None, + ) + mocker.patch( + "backend.api.features.chat.routes.enqueue_copilot_turn", + return_value=None, + ) + mocker.patch( + "backend.api.features.chat.routes.track_user_message", + return_value=None, + ) + mock_registry = mocker.MagicMock() + mock_registry.create_session = _AsyncMock(return_value=None) + mock_registry.subscribe_to_session = _AsyncMock(return_value=None) + mocker.patch( + "backend.api.features.chat.routes.stream_registry", + mock_registry, + ) + mock_redis = mocker.AsyncMock() + mock_redis.set = _AsyncMock(return_value=True) + # Make the delete raise so the except-pass branch is exercised. + mock_redis.delete = _AsyncMock(side_effect=RuntimeError("redis gone")) + mocker.patch( + "backend.copilot.message_dedup.get_redis_async", + new_callable=_AsyncMock, + return_value=mock_redis, + ) + + # Should not raise even though delete fails. + response = client.post( + "/sessions/sess-finish-err/stream", + json={"message": "hello", "is_user_message": True}, + ) + + assert response.status_code == 200 + assert '"finish"' in response.text + # delete must have been attempted — the except-pass branch silenced the error. + mock_redis.delete.assert_called_once() + + +# ─── DELETE /sessions/{id}/stream — disconnect listeners ────────────── + + +def test_disconnect_stream_returns_204_and_awaits_registry( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + mock_session = MagicMock() + mocker.patch( + "backend.api.features.chat.routes.get_chat_session", + new_callable=AsyncMock, + return_value=mock_session, + ) + mock_disconnect = mocker.patch( + "backend.api.features.chat.routes.stream_registry.disconnect_all_listeners", + new_callable=AsyncMock, + return_value=2, + ) + + response = client.delete("/sessions/sess-1/stream") + + assert response.status_code == 204 + mock_disconnect.assert_awaited_once_with("sess-1") + + +def test_disconnect_stream_returns_404_when_session_missing( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + mocker.patch( + "backend.api.features.chat.routes.get_chat_session", + new_callable=AsyncMock, + return_value=None, + ) + mock_disconnect = mocker.patch( + "backend.api.features.chat.routes.stream_registry.disconnect_all_listeners", + new_callable=AsyncMock, + ) + + response = client.delete("/sessions/unknown-session/stream") + + assert response.status_code == 404 + mock_disconnect.assert_not_awaited() diff --git a/autogpt_platform/backend/backend/copilot/config.py b/autogpt_platform/backend/backend/copilot/config.py index cfbc6feef4..d5418bf872 100644 --- a/autogpt_platform/backend/backend/copilot/config.py +++ b/autogpt_platform/backend/backend/copilot/config.py @@ -16,6 +16,13 @@ from backend.util.clients import OPENROUTER_BASE_URL # subscription flag → LaunchDarkly COPILOT_SDK → config.use_claude_agent_sdk. CopilotMode = Literal["fast", "extended_thinking"] +# Per-request model tier set by the frontend model toggle. 
+# 'standard' uses the global config default (currently Sonnet). +# 'advanced' forces the highest-capability model (currently Opus). +# None means no preference — falls through to LD per-user targeting, then config. +# Using tier names instead of model names keeps the contract model-agnostic. +CopilotLlmModel = Literal["standard", "advanced"] + class ChatConfig(BaseSettings): """Configuration for the chat system.""" @@ -163,12 +170,12 @@ class ChatConfig(BaseSettings): "CHAT_CLAUDE_AGENT_MAX_TURNS env var if your workflows need more.", ) claude_agent_max_budget_usd: float = Field( - default=15.0, + default=10.0, ge=0.01, le=1000.0, description="Maximum spend in USD per SDK query. The CLI attempts " "to wrap up gracefully when this budget is reached. " - "Set to $15 to allow most tasks to complete (p50=$5.37, p75=$13.07). " + "Set to $10 to allow most tasks to complete (p50=$5.37, p75=$13.07). " "Override via CHAT_CLAUDE_AGENT_MAX_BUDGET_USD env var.", ) claude_agent_max_thinking_tokens: int = Field( diff --git a/autogpt_platform/backend/backend/copilot/executor/processor.py b/autogpt_platform/backend/backend/copilot/executor/processor.py index cc83b2dd99..0266e57806 100644 --- a/autogpt_platform/backend/backend/copilot/executor/processor.py +++ b/autogpt_platform/backend/backend/copilot/executor/processor.py @@ -351,6 +351,7 @@ class CoPilotProcessor: context=entry.context, file_ids=entry.file_ids, mode=effective_mode, + model=entry.model, ) async for chunk in stream_registry.stream_and_publish( session_id=entry.session_id, diff --git a/autogpt_platform/backend/backend/copilot/executor/utils.py b/autogpt_platform/backend/backend/copilot/executor/utils.py index 0f7d23d9ba..3256f94869 100644 --- a/autogpt_platform/backend/backend/copilot/executor/utils.py +++ b/autogpt_platform/backend/backend/copilot/executor/utils.py @@ -9,7 +9,7 @@ import logging from pydantic import BaseModel -from backend.copilot.config import CopilotMode +from backend.copilot.config import CopilotLlmModel, CopilotMode from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig from backend.util.logging import TruncatedLogger, is_structured_logging_enabled @@ -160,6 +160,9 @@ class CoPilotExecutionEntry(BaseModel): mode: CopilotMode | None = None """Autopilot mode override: 'fast' or 'extended_thinking'. None = server default.""" + model: CopilotLlmModel | None = None + """Per-request model tier: 'standard' or 'advanced'. None = server default.""" + class CancelCoPilotEvent(BaseModel): """Event to cancel a CoPilot operation.""" @@ -180,6 +183,7 @@ async def enqueue_copilot_turn( context: dict[str, str] | None = None, file_ids: list[str] | None = None, mode: CopilotMode | None = None, + model: CopilotLlmModel | None = None, ) -> None: """Enqueue a CoPilot task for processing by the executor service. @@ -192,6 +196,7 @@ async def enqueue_copilot_turn( context: Optional context for the message (e.g., {url: str, content: str}) file_ids: Optional workspace file IDs attached to the user's message mode: Autopilot mode override ('fast' or 'extended_thinking'). None = server default. + model: Per-request model tier ('standard' or 'advanced'). None = server default. 
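+
+    Example (sketch) mirroring the call made by the /stream route:
+
+        await enqueue_copilot_turn(
+            session_id=session_id,
+            user_id=user_id,
+            message=request.message,
+            turn_id=str(uuid4()),
+            is_user_message=True,
+            model=request.model,
+        )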
""" from backend.util.clients import get_async_copilot_queue @@ -204,6 +209,7 @@ async def enqueue_copilot_turn( context=context, file_ids=file_ids, mode=mode, + model=model, ) queue_client = await get_async_copilot_queue() diff --git a/autogpt_platform/backend/backend/copilot/message_dedup.py b/autogpt_platform/backend/backend/copilot/message_dedup.py new file mode 100644 index 0000000000..2af13b559a --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/message_dedup.py @@ -0,0 +1,71 @@ +"""Per-request idempotency lock for the /stream endpoint. + +Prevents duplicate executor tasks from concurrent or retried POSTs (e.g. k8s +rolling-deploy retries, nginx upstream retries, rapid double-clicks). + +Lifecycle +--------- +1. ``acquire()`` — computes a stable hash of (session_id, message, file_ids) + and atomically sets a Redis NX key. Returns a ``_DedupLock`` on success or + ``None`` when the key already exists (duplicate request). +2. ``release()`` — deletes the key. Must be called on turn completion or turn + error so the next legitimate send is never blocked. +3. On client disconnect (``GeneratorExit``) the lock must NOT be released — + the backend turn is still running, and releasing would reopen the duplicate + window for infra-level retries. The 30 s TTL is the safety net. +""" + +import hashlib +import logging + +from backend.data.redis_client import get_redis_async + +logger = logging.getLogger(__name__) + +_KEY_PREFIX = "chat:msg_dedup" +_TTL_SECONDS = 30 + + +class _DedupLock: + def __init__(self, key: str, redis) -> None: + self._key = key + self._redis = redis + + async def release(self) -> None: + """Best-effort key deletion. The TTL handles failures silently.""" + try: + await self._redis.delete(self._key) + except Exception: + pass + + +async def acquire_dedup_lock( + session_id: str, + message: str | None, + file_ids: list[str] | None, +) -> _DedupLock | None: + """Acquire the idempotency lock for this (session, message, files) tuple. + + Returns a ``_DedupLock`` when the lock is freshly acquired (first request). + Returns ``None`` when a duplicate is detected (lock already held). + Returns ``None`` when there is nothing to deduplicate (no message, no files). 
+ """ + if not message and not file_ids: + return None + + sorted_ids = ":".join(sorted(file_ids or [])) + content_hash = hashlib.sha256( + f"{session_id}:{message or ''}:{sorted_ids}".encode() + ).hexdigest()[:16] + key = f"{_KEY_PREFIX}:{session_id}:{content_hash}" + + redis = await get_redis_async() + acquired = await redis.set(key, "1", ex=_TTL_SECONDS, nx=True) + if not acquired: + logger.warning( + f"[STREAM] Duplicate user message blocked for session {session_id}, " + f"hash={content_hash} — returning empty SSE", + ) + return None + + return _DedupLock(key, redis) diff --git a/autogpt_platform/backend/backend/copilot/message_dedup_test.py b/autogpt_platform/backend/backend/copilot/message_dedup_test.py new file mode 100644 index 0000000000..935ddd36b6 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/message_dedup_test.py @@ -0,0 +1,94 @@ +"""Unit tests for backend.copilot.message_dedup.""" + +from unittest.mock import AsyncMock + +import pytest +import pytest_mock + +from backend.copilot.message_dedup import _KEY_PREFIX, acquire_dedup_lock + + +def _patch_redis(mocker: pytest_mock.MockerFixture, *, set_returns): + mock_redis = AsyncMock() + mock_redis.set = AsyncMock(return_value=set_returns) + mocker.patch( + "backend.copilot.message_dedup.get_redis_async", + new_callable=AsyncMock, + return_value=mock_redis, + ) + return mock_redis + + +@pytest.mark.asyncio +async def test_acquire_returns_none_when_no_message_no_files( + mocker: pytest_mock.MockerFixture, +) -> None: + """Nothing to deduplicate — no Redis call made, None returned.""" + mock_redis = _patch_redis(mocker, set_returns=True) + result = await acquire_dedup_lock("sess-1", None, None) + assert result is None + mock_redis.set.assert_not_called() + + +@pytest.mark.asyncio +async def test_acquire_returns_lock_on_first_request( + mocker: pytest_mock.MockerFixture, +) -> None: + """First request acquires the lock and returns a _DedupLock.""" + mock_redis = _patch_redis(mocker, set_returns=True) + lock = await acquire_dedup_lock("sess-1", "hello", None) + assert lock is not None + mock_redis.set.assert_called_once() + key_arg = mock_redis.set.call_args.args[0] + assert key_arg.startswith(f"{_KEY_PREFIX}:sess-1:") + + +@pytest.mark.asyncio +async def test_acquire_returns_none_on_duplicate( + mocker: pytest_mock.MockerFixture, +) -> None: + """Duplicate request (NX fails) returns None to signal the caller.""" + _patch_redis(mocker, set_returns=None) + result = await acquire_dedup_lock("sess-1", "hello", None) + assert result is None + + +@pytest.mark.asyncio +async def test_acquire_key_stable_across_file_order( + mocker: pytest_mock.MockerFixture, +) -> None: + """File IDs are sorted before hashing so order doesn't affect the key.""" + mock_redis_1 = _patch_redis(mocker, set_returns=True) + await acquire_dedup_lock("sess-1", "msg", ["b", "a"]) + key_ab = mock_redis_1.set.call_args.args[0] + + mock_redis_2 = _patch_redis(mocker, set_returns=True) + await acquire_dedup_lock("sess-1", "msg", ["a", "b"]) + key_ba = mock_redis_2.set.call_args.args[0] + + assert key_ab == key_ba + + +@pytest.mark.asyncio +async def test_release_deletes_key( + mocker: pytest_mock.MockerFixture, +) -> None: + """release() calls Redis delete exactly once.""" + mock_redis = _patch_redis(mocker, set_returns=True) + lock = await acquire_dedup_lock("sess-1", "hello", None) + assert lock is not None + await lock.release() + mock_redis.delete.assert_called_once() + + +@pytest.mark.asyncio +async def test_release_swallows_redis_error( + mocker: 
pytest_mock.MockerFixture, +) -> None: + """release() must not raise even when Redis delete fails.""" + mock_redis = _patch_redis(mocker, set_returns=True) + mock_redis.delete = AsyncMock(side_effect=RuntimeError("redis down")) + lock = await acquire_dedup_lock("sess-1", "hello", None) + assert lock is not None + await lock.release() # must not raise + mock_redis.delete.assert_called_once() diff --git a/autogpt_platform/backend/backend/copilot/rate_limit.py b/autogpt_platform/backend/backend/copilot/rate_limit.py index f72d36de23..3124c28992 100644 --- a/autogpt_platform/backend/backend/copilot/rate_limit.py +++ b/autogpt_platform/backend/backend/copilot/rate_limit.py @@ -302,6 +302,7 @@ async def record_token_usage( *, cache_read_tokens: int = 0, cache_creation_tokens: int = 0, + model_cost_multiplier: float = 1.0, ) -> None: """Record token usage for a user across all windows. @@ -315,12 +316,17 @@ async def record_token_usage( ``prompt_tokens`` should be the *uncached* input count (``input_tokens`` from the API response). Cache counts are passed separately. + ``model_cost_multiplier`` scales the final weighted total to reflect + relative model cost. Use 5.0 for Opus (5× more expensive than Sonnet) + so that Opus turns deplete the rate limit faster, proportional to cost. + Args: user_id: The user's ID. prompt_tokens: Uncached input tokens. completion_tokens: Output tokens. cache_read_tokens: Tokens served from prompt cache (10% cost). cache_creation_tokens: Tokens written to prompt cache (25% cost). + model_cost_multiplier: Relative model cost factor (1.0 = Sonnet, 5.0 = Opus). """ prompt_tokens = max(0, prompt_tokens) completion_tokens = max(0, completion_tokens) @@ -332,7 +338,9 @@ async def record_token_usage( + round(cache_creation_tokens * 0.25) + round(cache_read_tokens * 0.1) ) - total = weighted_input + completion_tokens + total = round( + (weighted_input + completion_tokens) * max(1.0, model_cost_multiplier) + ) if total <= 0: return @@ -340,11 +348,12 @@ async def record_token_usage( prompt_tokens + cache_read_tokens + cache_creation_tokens + completion_tokens ) logger.info( - "Recording token usage for %s: raw=%d, weighted=%d " + "Recording token usage for %s: raw=%d, weighted=%d, multiplier=%.1fx " "(uncached=%d, cache_read=%d@10%%, cache_create=%d@25%%, output=%d)", user_id[:8], raw_total, total, + model_cost_multiplier, prompt_tokens, cache_read_tokens, cache_creation_tokens, diff --git a/autogpt_platform/backend/backend/copilot/sdk/agent_generation_guide.md b/autogpt_platform/backend/backend/copilot/sdk/agent_generation_guide.md index 35b4a348b9..145354b704 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/agent_generation_guide.md +++ b/autogpt_platform/backend/backend/copilot/sdk/agent_generation_guide.md @@ -34,9 +34,13 @@ Steps: always inspect the current graph first so you know exactly what to change. Avoid using `include_graph=true` with broad keyword searches, as fetching multiple graphs at once is expensive and consumes LLM context budget. -2. **Discover blocks**: Call `find_block(query, include_schemas=true)` to +2. **Discover blocks**: Call `find_block(query, include_schemas=true, for_agent_generation=true)` to search for relevant blocks. This returns block IDs, names, descriptions, - and full input/output schemas. + and full input/output schemas. The `for_agent_generation=true` flag is + required to surface graph-only blocks such as AgentInputBlock, + AgentDropdownInputBlock, AgentOutputBlock, OrchestratorBlock, + and WebhookBlock and MCPToolBlock. 
(When running MCP tools interactively + in CoPilot outside agent generation, use `run_mcp_tool` instead.) 3. **Find library agents**: Call `find_library_agent` to discover reusable agents that can be composed as sub-agents via `AgentExecutorBlock`. 4. **Generate/modify JSON**: Build or modify the agent JSON using block schemas: @@ -177,6 +181,12 @@ To compose agents using other agents as sub-agents: ### Using MCP Tools (MCPToolBlock) +> **Agent graph vs CoPilot direct execution**: This section covers embedding MCP +> tools as persistent nodes in an agent graph. When running MCP tools directly in +> CoPilot (outside agent generation), use `run_mcp_tool` instead — it handles +> server discovery and authentication interactively. Use `MCPToolBlock` here only +> when the user wants the MCP call baked into a reusable agent graph. + To use an MCP (Model Context Protocol) tool as a node in the agent: 1. The user must specify which MCP server URL and tool name they want 2. Create an `MCPToolBlock` node (ID: `a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4`) diff --git a/autogpt_platform/backend/backend/copilot/sdk/p0_guardrails_test.py b/autogpt_platform/backend/backend/copilot/sdk/p0_guardrails_test.py index 7077337a79..9305320fea 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/p0_guardrails_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/p0_guardrails_test.py @@ -207,7 +207,7 @@ class TestConfigDefaults: def test_max_budget_usd_default(self): cfg = _make_config() - assert cfg.claude_agent_max_budget_usd == 15.0 + assert cfg.claude_agent_max_budget_usd == 10.0 def test_max_thinking_tokens_default(self): cfg = _make_config() diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py index 8c70dad90f..77e0c16a88 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service.py @@ -56,7 +56,7 @@ from backend.executor.cluster_lock import AsyncClusterLock from backend.util.exceptions import NotFoundError from backend.util.settings import Settings -from ..config import ChatConfig, CopilotMode +from ..config import ChatConfig, CopilotLlmModel, CopilotMode from ..constants import ( COPILOT_ERROR_PREFIX, COPILOT_RETRYABLE_ERROR_PREFIX, @@ -132,6 +132,11 @@ _MAX_STREAM_ATTEMPTS = 3 # self-correct. The limit is generous to allow recovery attempts. _EMPTY_TOOL_CALL_LIMIT = 5 +# Cost multiplier for Opus model turns — Opus is ~5× more expensive than Sonnet +# ($15/$75 vs $3/$15 per M tokens). Applied to rate-limit counters so Opus +# turns deplete quota proportionally faster. +_OPUS_COST_MULTIPLIER = 5.0 + # User-facing error shown when the empty-tool-call circuit breaker trips. _CIRCUIT_BREAKER_ERROR_MSG = ( "AutoPilot was unable to complete the tool call " @@ -674,6 +679,48 @@ def _resolve_fallback_model() -> str | None: return _normalize_model_name(raw) +async def _resolve_model_and_multiplier( + model: "CopilotLlmModel | None", + session_id: str, +) -> tuple[str | None, float]: + """Resolve the SDK model string and rate-limit cost multiplier for a turn. + + Priority (highest first): + 1. Explicit per-request ``model`` tier from the frontend toggle. + 2. Global config default (``_resolve_sdk_model()``). + + Returns a ``(sdk_model, cost_multiplier)`` pair. + ``sdk_model`` is ``None`` when the Claude Code subscription default applies. + ``cost_multiplier`` is 5.0 for Opus, 1.0 otherwise. 
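+
+    Examples (sketch): ``model="advanced"`` resolves to
+    ``("claude-opus-4-6", 5.0)``; ``model=None`` with a Sonnet config
+    default resolves to ``("claude-sonnet-4", 1.0)``.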
+ """ + sdk_model = _resolve_sdk_model() + + if model == "advanced": + sdk_model = _normalize_model_name("anthropic/claude-opus-4-6") + logger.info( + "[SDK] [%s] Per-request model override: advanced (%s)", + session_id[:12] if session_id else "?", + sdk_model, + ) + return sdk_model, _OPUS_COST_MULTIPLIER + + if model == "standard": + # Reset to config default — respects subscription mode (None = CLI default). + sdk_model = _resolve_sdk_model() + logger.info( + "[SDK] [%s] Per-request model override: standard (%s)", + session_id[:12] if session_id else "?", + sdk_model or "subscription-default", + ) + return sdk_model, 1.0 + + # No per-request override; derive multiplier from final resolved model. + cost_multiplier = ( + _OPUS_COST_MULTIPLIER if sdk_model and "opus" in sdk_model else 1.0 + ) + return sdk_model, cost_multiplier + + _MAX_TRANSIENT_BACKOFF_SECONDS = 30 @@ -1865,15 +1912,20 @@ async def _run_stream_attempt( # cache_read_input_tokens = served from cache # cache_creation_input_tokens = written to cache if sdk_msg.usage: - state.usage.prompt_tokens += sdk_msg.usage.get("input_tokens", 0) - state.usage.cache_read_tokens += sdk_msg.usage.get( - "cache_read_input_tokens", 0 + # Use `or 0` instead of a default in .get() because + # OpenRouter may include the key with a null value (e.g. + # {"cache_read_input_tokens": null}) for models that don't + # yet report cache tokens, making .get("key", 0) return + # None rather than the fallback 0. + state.usage.prompt_tokens += sdk_msg.usage.get("input_tokens") or 0 + state.usage.cache_read_tokens += ( + sdk_msg.usage.get("cache_read_input_tokens") or 0 ) - state.usage.cache_creation_tokens += sdk_msg.usage.get( - "cache_creation_input_tokens", 0 + state.usage.cache_creation_tokens += ( + sdk_msg.usage.get("cache_creation_input_tokens") or 0 ) - state.usage.completion_tokens += sdk_msg.usage.get( - "output_tokens", 0 + state.usage.completion_tokens += ( + sdk_msg.usage.get("output_tokens") or 0 ) logger.info( "%s Token usage: uncached=%d, cache_read=%d, " @@ -2150,6 +2202,7 @@ async def stream_chat_completion_sdk( file_ids: list[str] | None = None, permissions: "CopilotPermissions | None" = None, mode: CopilotMode | None = None, + model: CopilotLlmModel | None = None, **_kwargs: Any, ) -> AsyncIterator[StreamBaseResponse]: """Stream chat completion using Claude Agent SDK. @@ -2160,6 +2213,9 @@ async def stream_chat_completion_sdk( saved to the SDK working directory for the Read tool. mode: Accepted for signature compatibility with the baseline path. The SDK path does not currently branch on this value. + model: Per-request model preference from the frontend toggle. + 'advanced' → Claude Opus; 'standard' → global config default. + Takes priority over per-user LaunchDarkly targeting. """ _ = mode # SDK path ignores the requested mode. @@ -2274,6 +2330,10 @@ async def stream_chat_completion_sdk( turn_cache_creation_tokens = 0 turn_cost_usd: float | None = None graphiti_enabled = False + # Defaults ensure the finally block can always reference these safely even when + # an early return (e.g. sdk_cwd error) skips their normal assignment below. + sdk_model: str | None = None + model_cost_multiplier: float = 1.0 # Make sure there is no more code between the lock acquisition and try-block. try: @@ -2487,7 +2547,10 @@ async def stream_chat_completion_sdk( mcp_server = create_copilot_mcp_server(use_e2b=use_e2b) - sdk_model = _resolve_sdk_model() + # Resolve model and cost multiplier (request tier → config default). 
+ sdk_model, model_cost_multiplier = await _resolve_model_and_multiplier( + model, session_id + ) # Track SDK-internal compaction (PreCompact hook → start, next msg → end) compaction = CompactionTracker() @@ -3188,8 +3251,9 @@ async def stream_chat_completion_sdk( cache_creation_tokens=turn_cache_creation_tokens, log_prefix=log_prefix, cost_usd=turn_cost_usd, - model=config.model, + model=sdk_model or config.model, provider="anthropic", + model_cost_multiplier=model_cost_multiplier, ) # --- Persist session messages --- diff --git a/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py b/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py index be2c46bdbb..9d8b4bb135 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py @@ -20,7 +20,9 @@ from .service import ( _is_prompt_too_long, _is_tool_only_message, _iter_sdk_messages, + _normalize_model_name, _reduce_context, + _TokenUsage, ) # --------------------------------------------------------------------------- @@ -350,3 +352,128 @@ class TestIsParallelContinuation: msg = MagicMock(spec=AssistantMessage) msg.content = [self._make_tool_block()] assert _is_tool_only_message(msg) is True + + +# --------------------------------------------------------------------------- +# _normalize_model_name — used by per-request model override +# --------------------------------------------------------------------------- + + +class TestNormalizeModelName: + """Unit tests for the model-name normalisation helper. + + The per-request model toggle calls _normalize_model_name with either + ``"anthropic/claude-opus-4-6"`` (for 'advanced') or ``config.model`` (for + 'standard'). These tests verify the OpenRouter/provider-prefix stripping + that keeps the value compatible with the Claude CLI. + """ + + def test_strips_anthropic_prefix(self): + assert _normalize_model_name("anthropic/claude-opus-4-6") == "claude-opus-4-6" + + def test_strips_openai_prefix(self): + assert _normalize_model_name("openai/gpt-4o") == "gpt-4o" + + def test_strips_google_prefix(self): + assert _normalize_model_name("google/gemini-2.5-flash") == "gemini-2.5-flash" + + def test_already_normalized_unchanged(self): + assert ( + _normalize_model_name("claude-sonnet-4-20250514") + == "claude-sonnet-4-20250514" + ) + + def test_empty_string_unchanged(self): + assert _normalize_model_name("") == "" + + def test_opus_model_roundtrip(self): + """The exact string used for the 'opus' toggle strips correctly.""" + assert _normalize_model_name("anthropic/claude-opus-4-6") == "claude-opus-4-6" + + def test_sonnet_openrouter_model(self): + """Sonnet model as stored in config (OpenRouter-prefixed) strips cleanly.""" + assert _normalize_model_name("anthropic/claude-sonnet-4") == "claude-sonnet-4" + + +# --------------------------------------------------------------------------- +# _TokenUsage — null-safe accumulation (OpenRouter initial-stream-event bug) +# --------------------------------------------------------------------------- + + +class TestTokenUsageNullSafety: + """Verify that ResultMessage.usage dicts with null-valued cache fields + (as emitted by OpenRouter for the initial streaming event before real + token counts are available) do not crash the accumulator. + + Before the fix, dict.get("cache_read_input_tokens", 0) returned None + when the key existed with a null value, causing 'int += None' TypeError. 
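+
+    Minimal repro of the failure mode:
+    ``{"cache_read_input_tokens": None}.get("cache_read_input_tokens", 0)``
+    returns the stored ``None`` (the key exists, so the default is unused),
+    whereas ``.get("cache_read_input_tokens") or 0`` coerces it to ``0``.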
+ """ + + def _apply_usage(self, usage: dict, acc: _TokenUsage) -> None: + """Mirror the production accumulation in sdk/service.py.""" + acc.prompt_tokens += usage.get("input_tokens") or 0 + acc.cache_read_tokens += usage.get("cache_read_input_tokens") or 0 + acc.cache_creation_tokens += usage.get("cache_creation_input_tokens") or 0 + acc.completion_tokens += usage.get("output_tokens") or 0 + + def test_null_cache_tokens_do_not_crash(self): + """OpenRouter initial event: cache keys present with null value.""" + usage = { + "input_tokens": 0, + "output_tokens": 0, + "cache_read_input_tokens": None, + "cache_creation_input_tokens": None, + } + acc = _TokenUsage() + self._apply_usage(usage, acc) # must not raise TypeError + assert acc.prompt_tokens == 0 + assert acc.cache_read_tokens == 0 + assert acc.cache_creation_tokens == 0 + assert acc.completion_tokens == 0 + + def test_real_cache_tokens_are_accumulated(self): + """OpenRouter final event: real cache token counts are captured.""" + usage = { + "input_tokens": 10, + "output_tokens": 349, + "cache_read_input_tokens": 16600, + "cache_creation_input_tokens": 512, + } + acc = _TokenUsage() + self._apply_usage(usage, acc) + assert acc.prompt_tokens == 10 + assert acc.cache_read_tokens == 16600 + assert acc.cache_creation_tokens == 512 + assert acc.completion_tokens == 349 + + def test_absent_cache_keys_default_to_zero(self): + """Minimal usage dict without cache keys defaults correctly.""" + usage = {"input_tokens": 5, "output_tokens": 20} + acc = _TokenUsage() + self._apply_usage(usage, acc) + assert acc.prompt_tokens == 5 + assert acc.cache_read_tokens == 0 + assert acc.cache_creation_tokens == 0 + assert acc.completion_tokens == 20 + + def test_multi_turn_accumulation(self): + """Null event followed by real event: only real tokens counted.""" + null_event = { + "input_tokens": 0, + "output_tokens": 0, + "cache_read_input_tokens": None, + "cache_creation_input_tokens": None, + } + real_event = { + "input_tokens": 10, + "output_tokens": 349, + "cache_read_input_tokens": 16600, + "cache_creation_input_tokens": 512, + } + acc = _TokenUsage() + self._apply_usage(null_event, acc) + self._apply_usage(real_event, acc) + assert acc.prompt_tokens == 10 + assert acc.cache_read_tokens == 16600 + assert acc.cache_creation_tokens == 512 + assert acc.completion_tokens == 349 diff --git a/autogpt_platform/backend/backend/copilot/stream_registry.py b/autogpt_platform/backend/backend/copilot/stream_registry.py index 163b8c1bab..030763dbca 100644 --- a/autogpt_platform/backend/backend/copilot/stream_registry.py +++ b/autogpt_platform/backend/backend/copilot/stream_registry.py @@ -1149,3 +1149,50 @@ async def unsubscribe_from_session( ) logger.debug(f"Successfully unsubscribed from session {session_id}") + + +async def disconnect_all_listeners(session_id: str) -> int: + """Cancel every active listener task for *session_id*. + + Called when the frontend switches away from a session and wants the + backend to release resources immediately rather than waiting for the + XREAD timeout. + + Scope / limitations (best-effort optimisation, not a correctness primitive): + - Pod-local: ``_listener_sessions`` is in-memory. If the DELETE request + lands on a different worker than the one serving the SSE, no listener + is cancelled here — the SSE worker still releases on its XREAD timeout. + - Session-scoped (not subscriber-scoped): cancels every active listener + for the session on this pod. 
In the rare case a single user opens two + SSE connections to the same session on the same pod (e.g. two tabs), + both would be torn down. Cross-pod, subscriber-scoped cancellation + would require a Redis pub/sub fan-out with per-listener tokens; that + is not implemented here because the XREAD timeout already bounds the + worst case. + + Returns the number of listener tasks that were cancelled. + """ + to_cancel: list[tuple[int, asyncio.Task]] = [ + (qid, task) + for qid, (sid, task) in list(_listener_sessions.items()) + if sid == session_id and not task.done() + ] + + for qid, task in to_cancel: + _listener_sessions.pop(qid, None) + task.cancel() + + cancelled = 0 + for _qid, task in to_cancel: + try: + await asyncio.wait_for(task, timeout=5.0) + except asyncio.CancelledError: + cancelled += 1 + except asyncio.TimeoutError: + pass + except Exception as e: + logger.error(f"Error cancelling listener for session {session_id}: {e}") + + if cancelled: + logger.info(f"Disconnected {cancelled} listener(s) for session {session_id}") + return cancelled diff --git a/autogpt_platform/backend/backend/copilot/stream_registry_test.py b/autogpt_platform/backend/backend/copilot/stream_registry_test.py new file mode 100644 index 0000000000..a09940a4a8 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/stream_registry_test.py @@ -0,0 +1,110 @@ +"""Tests for disconnect_all_listeners in stream_registry.""" + +import asyncio +from unittest.mock import AsyncMock, patch + +import pytest + +from backend.copilot import stream_registry + + +@pytest.fixture(autouse=True) +def _clear_listener_sessions(): + stream_registry._listener_sessions.clear() + yield + stream_registry._listener_sessions.clear() + + +async def _sleep_forever(): + try: + await asyncio.sleep(3600) + except asyncio.CancelledError: + raise + + +@pytest.mark.asyncio +async def test_disconnect_all_listeners_cancels_matching_session(): + task_a = asyncio.create_task(_sleep_forever()) + task_b = asyncio.create_task(_sleep_forever()) + task_other = asyncio.create_task(_sleep_forever()) + + stream_registry._listener_sessions[1] = ("sess-1", task_a) + stream_registry._listener_sessions[2] = ("sess-1", task_b) + stream_registry._listener_sessions[3] = ("sess-other", task_other) + + try: + cancelled = await stream_registry.disconnect_all_listeners("sess-1") + + assert cancelled == 2 + assert task_a.cancelled() + assert task_b.cancelled() + assert not task_other.done() + # Matching entries are removed, non-matching entries remain. 
+ assert 1 not in stream_registry._listener_sessions + assert 2 not in stream_registry._listener_sessions + assert 3 in stream_registry._listener_sessions + finally: + task_other.cancel() + try: + await task_other + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_disconnect_all_listeners_no_match_returns_zero(): + task = asyncio.create_task(_sleep_forever()) + stream_registry._listener_sessions[1] = ("sess-other", task) + + try: + cancelled = await stream_registry.disconnect_all_listeners("sess-missing") + + assert cancelled == 0 + assert not task.done() + assert 1 in stream_registry._listener_sessions + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_disconnect_all_listeners_skips_already_done_tasks(): + async def _noop(): + return None + + done_task = asyncio.create_task(_noop()) + await done_task + stream_registry._listener_sessions[1] = ("sess-1", done_task) + + cancelled = await stream_registry.disconnect_all_listeners("sess-1") + + # Done tasks are filtered out before cancellation. + assert cancelled == 0 + + +@pytest.mark.asyncio +async def test_disconnect_all_listeners_empty_registry(): + cancelled = await stream_registry.disconnect_all_listeners("sess-1") + assert cancelled == 0 + + +@pytest.mark.asyncio +async def test_disconnect_all_listeners_timeout_not_counted(): + """Tasks that don't respond to cancellation (timeout) are not counted.""" + task = asyncio.create_task(_sleep_forever()) + stream_registry._listener_sessions[1] = ("sess-1", task) + + with patch.object( + asyncio, "wait_for", new=AsyncMock(side_effect=asyncio.TimeoutError) + ): + cancelled = await stream_registry.disconnect_all_listeners("sess-1") + + assert cancelled == 0 + task.cancel() + try: + await task + except asyncio.CancelledError: + pass diff --git a/autogpt_platform/backend/backend/copilot/token_tracking.py b/autogpt_platform/backend/backend/copilot/token_tracking.py index e84b64d449..19406ced93 100644 --- a/autogpt_platform/backend/backend/copilot/token_tracking.py +++ b/autogpt_platform/backend/backend/copilot/token_tracking.py @@ -96,6 +96,7 @@ async def persist_and_record_usage( cost_usd: float | str | None = None, model: str | None = None, provider: str = "open_router", + model_cost_multiplier: float = 1.0, ) -> int: """Persist token usage to session and record for rate limiting. @@ -109,6 +110,9 @@ async def persist_and_record_usage( log_prefix: Prefix for log messages (e.g. "[SDK]", "[Baseline]"). cost_usd: Optional cost for logging (float from SDK, str otherwise). provider: Cost provider name (e.g. "anthropic", "open_router"). + model_cost_multiplier: Relative model cost factor for rate limiting + (1.0 = Sonnet/default, 5.0 = Opus). Scales the token counter so + more expensive models deplete the rate limit proportionally faster. Returns: The computed total_tokens (prompt + completion; cache excluded). 
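+
+        Example (sketch): an Opus turn (multiplier 5.0) with 1000 uncached
+        input and 500 output tokens records round((1000 + 500) * 5.0) = 7500
+        weighted tokens via record_token_usage, while the returned
+        total_tokens stays 1500.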
@@ -163,6 +167,7 @@ async def persist_and_record_usage( completion_tokens=completion_tokens, cache_read_tokens=cache_read_tokens, cache_creation_tokens=cache_creation_tokens, + model_cost_multiplier=model_cost_multiplier, ) except Exception as usage_err: logger.warning("%s Failed to record token usage: %s", log_prefix, usage_err) diff --git a/autogpt_platform/backend/backend/copilot/token_tracking_test.py b/autogpt_platform/backend/backend/copilot/token_tracking_test.py index 04c7667368..11757ce541 100644 --- a/autogpt_platform/backend/backend/copilot/token_tracking_test.py +++ b/autogpt_platform/backend/backend/copilot/token_tracking_test.py @@ -230,6 +230,7 @@ class TestRateLimitRecording: completion_tokens=50, cache_read_tokens=1000, cache_creation_tokens=200, + model_cost_multiplier=1.0, ) @pytest.mark.asyncio diff --git a/autogpt_platform/backend/backend/copilot/tools/find_block.py b/autogpt_platform/backend/backend/copilot/tools/find_block.py index 0cbc3ba047..130e26562b 100644 --- a/autogpt_platform/backend/backend/copilot/tools/find_block.py +++ b/autogpt_platform/backend/backend/copilot/tools/find_block.py @@ -74,6 +74,15 @@ class FindBlockTool(BaseTool): "description": "Include full input/output schemas (for agent JSON generation).", "default": False, }, + "for_agent_generation": { + "type": "boolean", + "description": ( + "Set to true when searching for blocks to use inside an agent graph " + "(e.g. AgentInputBlock, AgentOutputBlock, OrchestratorBlock). " + "Bypasses the CoPilot-only filter so graph-only blocks are visible." + ), + "default": False, + }, }, "required": ["query"], } @@ -88,6 +97,7 @@ class FindBlockTool(BaseTool): session: ChatSession, query: str = "", include_schemas: bool = False, + for_agent_generation: bool = False, **kwargs, ) -> ToolResponseBase: """Search for blocks matching the query. @@ -97,6 +107,8 @@ class FindBlockTool(BaseTool): session: Chat session query: Search query include_schemas: Whether to include block schemas in results + for_agent_generation: When True, bypasses the CoPilot exclusion filter + so graph-only blocks (INPUT, OUTPUT, ORCHESTRATOR, etc.) are visible. Returns: BlockListResponse: List of matching blocks @@ -123,34 +135,36 @@ class FindBlockTool(BaseTool): suggestions=["Search for an alternative block by name"], session_id=session_id, ) - if ( + is_excluded = ( block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES or block.id in COPILOT_EXCLUDED_BLOCK_IDS - ): - if block.block_type == BlockType.MCP_TOOL: + ) + if is_excluded: + # Graph-only blocks (INPUT, OUTPUT, MCP_TOOL, AGENT, etc.) are + # exposed when building an agent graph so the LLM can inspect + # their schemas and wire them as nodes. In CoPilot direct use + # they are not executable — guide the LLM to the right tool. + if not for_agent_generation: + if block.block_type == BlockType.MCP_TOOL: + message = ( + f"Block '{block.name}' (ID: {block.id}) cannot be " + "run directly in CoPilot. Use run_mcp_tool for " + "interactive MCP execution, or call find_block with " + "for_agent_generation=true to embed it in an agent graph." + ) + else: + message = ( + f"Block '{block.name}' (ID: {block.id}) is not available " + "in CoPilot. It can only be used within agent graphs." + ) return NoResultsResponse( - message=( - f"Block '{block.name}' (ID: {block.id}) is not " - "runnable through find_block/run_block. Use " - "run_mcp_tool instead." 
- ), + message=message, suggestions=[ - "Use run_mcp_tool to discover and run this MCP tool", "Search for an alternative block by name", + "Use this block in an agent graph instead", ], session_id=session_id, ) - return NoResultsResponse( - message=( - f"Block '{block.name}' (ID: {block.id}) is not available " - "in CoPilot. It can only be used within agent graphs." - ), - suggestions=[ - "Search for an alternative block by name", - "Use this block in an agent graph instead", - ], - session_id=session_id, - ) # Check block-level permissions — hide denied blocks entirely perms = get_current_permissions() @@ -221,8 +235,9 @@ class FindBlockTool(BaseTool): if not block or block.disabled: continue - # Skip blocks excluded from CoPilot (graph-only blocks) - if ( + # Graph-only blocks (INPUT, OUTPUT, MCP_TOOL, AGENT, etc.) are + # skipped in CoPilot direct use but surfaced for agent graph building. + if not for_agent_generation and ( block.block_type in COPILOT_EXCLUDED_BLOCK_TYPES or block.id in COPILOT_EXCLUDED_BLOCK_IDS ): diff --git a/autogpt_platform/backend/backend/copilot/tools/find_block_test.py b/autogpt_platform/backend/backend/copilot/tools/find_block_test.py index 64a7fe3788..d99672daa2 100644 --- a/autogpt_platform/backend/backend/copilot/tools/find_block_test.py +++ b/autogpt_platform/backend/backend/copilot/tools/find_block_test.py @@ -12,7 +12,7 @@ from .find_block import ( COPILOT_EXCLUDED_BLOCK_TYPES, FindBlockTool, ) -from .models import BlockListResponse +from .models import BlockListResponse, NoResultsResponse _TEST_USER_ID = "test-user-find-block" @@ -166,6 +166,194 @@ class TestFindBlockFiltering: assert len(response.blocks) == 1 assert response.blocks[0].id == "normal-block-id" + @pytest.mark.asyncio(loop_scope="session") + async def test_for_agent_generation_exposes_excluded_blocks_in_search(self): + """With for_agent_generation=True, excluded block types appear in search results.""" + session = make_session(user_id=_TEST_USER_ID) + + search_results = [ + {"content_id": "input-block-id", "score": 0.9}, + {"content_id": "output-block-id", "score": 0.8}, + ] + input_block = make_mock_block("input-block-id", "Agent Input", BlockType.INPUT) + output_block = make_mock_block( + "output-block-id", "Agent Output", BlockType.OUTPUT + ) + + def mock_get_block(block_id): + return { + "input-block-id": input_block, + "output-block-id": output_block, + }.get(block_id) + + mock_search_db = MagicMock() + mock_search_db.unified_hybrid_search = AsyncMock( + return_value=(search_results, 2) + ) + + with patch( + "backend.copilot.tools.find_block.search", + return_value=mock_search_db, + ): + with patch( + "backend.copilot.tools.find_block.get_block", + side_effect=mock_get_block, + ): + tool = FindBlockTool() + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + query="agent input", + for_agent_generation=True, + ) + + assert isinstance(response, BlockListResponse) + assert len(response.blocks) == 2 + block_ids = {b.id for b in response.blocks} + assert "input-block-id" in block_ids + assert "output-block-id" in block_ids + + @pytest.mark.asyncio(loop_scope="session") + async def test_mcp_tool_exposed_with_for_agent_generation_in_search(self): + """MCP_TOOL blocks appear in search results when for_agent_generation=True.""" + session = make_session(user_id=_TEST_USER_ID) + + search_results = [ + {"content_id": "mcp-block-id", "score": 0.9}, + {"content_id": "standard-block-id", "score": 0.8}, + ] + mcp_block = make_mock_block("mcp-block-id", "MCP Tool", 
BlockType.MCP_TOOL) + standard_block = make_mock_block( + "standard-block-id", "Normal Block", BlockType.STANDARD + ) + + def mock_get_block(block_id): + return { + "mcp-block-id": mcp_block, + "standard-block-id": standard_block, + }.get(block_id) + + mock_search_db = MagicMock() + mock_search_db.unified_hybrid_search = AsyncMock( + return_value=(search_results, 2) + ) + + with patch( + "backend.copilot.tools.find_block.search", + return_value=mock_search_db, + ): + with patch( + "backend.copilot.tools.find_block.get_block", + side_effect=mock_get_block, + ): + tool = FindBlockTool() + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + query="mcp tool", + for_agent_generation=True, + ) + + assert isinstance(response, BlockListResponse) + assert len(response.blocks) == 2 + assert any(b.id == "mcp-block-id" for b in response.blocks) + assert any(b.id == "standard-block-id" for b in response.blocks) + + @pytest.mark.asyncio(loop_scope="session") + async def test_mcp_tool_excluded_without_for_agent_generation_in_search(self): + """MCP_TOOL blocks are excluded from search in normal CoPilot mode.""" + session = make_session(user_id=_TEST_USER_ID) + + search_results = [ + {"content_id": "mcp-block-id", "score": 0.9}, + {"content_id": "standard-block-id", "score": 0.8}, + ] + mcp_block = make_mock_block("mcp-block-id", "MCP Tool", BlockType.MCP_TOOL) + standard_block = make_mock_block( + "standard-block-id", "Normal Block", BlockType.STANDARD + ) + + def mock_get_block(block_id): + return { + "mcp-block-id": mcp_block, + "standard-block-id": standard_block, + }.get(block_id) + + mock_search_db = MagicMock() + mock_search_db.unified_hybrid_search = AsyncMock( + return_value=(search_results, 2) + ) + + with patch( + "backend.copilot.tools.find_block.search", + return_value=mock_search_db, + ): + with patch( + "backend.copilot.tools.find_block.get_block", + side_effect=mock_get_block, + ): + tool = FindBlockTool() + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + query="mcp tool", + for_agent_generation=False, + ) + + assert isinstance(response, BlockListResponse) + assert len(response.blocks) == 1 + assert response.blocks[0].id == "standard-block-id" + + @pytest.mark.asyncio(loop_scope="session") + async def test_for_agent_generation_exposes_excluded_ids_in_search(self): + """With for_agent_generation=True, excluded block IDs appear in search results.""" + session = make_session(user_id=_TEST_USER_ID) + orchestrator_id = next(iter(COPILOT_EXCLUDED_BLOCK_IDS)) + + search_results = [ + {"content_id": orchestrator_id, "score": 0.9}, + {"content_id": "normal-block-id", "score": 0.8}, + ] + orchestrator_block = make_mock_block( + orchestrator_id, "Orchestrator", BlockType.STANDARD + ) + normal_block = make_mock_block( + "normal-block-id", "Normal Block", BlockType.STANDARD + ) + + def mock_get_block(block_id): + return { + orchestrator_id: orchestrator_block, + "normal-block-id": normal_block, + }.get(block_id) + + mock_search_db = MagicMock() + mock_search_db.unified_hybrid_search = AsyncMock( + return_value=(search_results, 2) + ) + + with patch( + "backend.copilot.tools.find_block.search", + return_value=mock_search_db, + ): + with patch( + "backend.copilot.tools.find_block.get_block", + side_effect=mock_get_block, + ): + tool = FindBlockTool() + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + query="orchestrator", + for_agent_generation=True, + ) + + assert isinstance(response, BlockListResponse) + assert 
len(response.blocks) == 2 + block_ids = {b.id for b in response.blocks} + assert orchestrator_id in block_ids + assert "normal-block-id" in block_ids + @pytest.mark.asyncio(loop_scope="session") async def test_response_size_average_chars_per_block(self): """Measure average chars per block in the serialized response.""" @@ -549,8 +737,6 @@ class TestFindBlockDirectLookup: user_id=_TEST_USER_ID, session=session, query=block_id ) - from .models import NoResultsResponse - assert isinstance(response, NoResultsResponse) @pytest.mark.asyncio(loop_scope="session") @@ -571,8 +757,6 @@ class TestFindBlockDirectLookup: user_id=_TEST_USER_ID, session=session, query=block_id ) - from .models import NoResultsResponse - assert isinstance(response, NoResultsResponse) assert "disabled" in response.message.lower() @@ -592,8 +776,6 @@ class TestFindBlockDirectLookup: user_id=_TEST_USER_ID, session=session, query=block_id ) - from .models import NoResultsResponse - assert isinstance(response, NoResultsResponse) assert "not available" in response.message.lower() @@ -613,7 +795,74 @@ class TestFindBlockDirectLookup: user_id=_TEST_USER_ID, session=session, query=orchestrator_id ) - from .models import NoResultsResponse - assert isinstance(response, NoResultsResponse) assert "not available" in response.message.lower() + + @pytest.mark.asyncio(loop_scope="session") + async def test_uuid_lookup_excluded_block_type_allowed_with_for_agent_generation( + self, + ): + """With for_agent_generation=True, excluded block types (INPUT) are visible.""" + session = make_session(user_id=_TEST_USER_ID) + block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d" + block = make_mock_block(block_id, "Agent Input Block", BlockType.INPUT) + + with patch( + "backend.copilot.tools.find_block.get_block", + return_value=block, + ): + tool = FindBlockTool() + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + query=block_id, + for_agent_generation=True, + ) + + assert isinstance(response, BlockListResponse) + assert response.count == 1 + assert response.blocks[0].id == block_id + + @pytest.mark.asyncio(loop_scope="session") + async def test_uuid_lookup_mcp_tool_exposed_with_for_agent_generation(self): + """MCP_TOOL blocks are returned by UUID lookup when for_agent_generation=True.""" + session = make_session(user_id=_TEST_USER_ID) + block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d" + block = make_mock_block(block_id, "MCP Tool", BlockType.MCP_TOOL) + + with patch( + "backend.copilot.tools.find_block.get_block", + return_value=block, + ): + tool = FindBlockTool() + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + query=block_id, + for_agent_generation=True, + ) + + assert isinstance(response, BlockListResponse) + assert response.blocks[0].id == block_id + + @pytest.mark.asyncio(loop_scope="session") + async def test_uuid_lookup_mcp_tool_excluded_without_for_agent_generation(self): + """MCP_TOOL blocks are excluded by UUID lookup in normal CoPilot mode.""" + session = make_session(user_id=_TEST_USER_ID) + block_id = "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d" + block = make_mock_block(block_id, "MCP Tool", BlockType.MCP_TOOL) + + with patch( + "backend.copilot.tools.find_block.get_block", + return_value=block, + ): + tool = FindBlockTool() + response = await tool._execute( + user_id=_TEST_USER_ID, + session=session, + query=block_id, + for_agent_generation=False, + ) + + assert isinstance(response, NoResultsResponse) + assert "run_mcp_tool" in response.message diff --git 
a/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/helpers.test.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/helpers.test.ts index 712aaaf508..9580ef349a 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/helpers.test.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/helpers.test.ts @@ -1,6 +1,7 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; import { IMPERSONATION_HEADER_NAME } from "@/lib/constants"; -import { getCopilotAuthHeaders } from "../helpers"; +import { getCopilotAuthHeaders, getSendSuppressionReason } from "../helpers"; +import type { UIMessage } from "ai"; vi.mock("@/lib/supabase/actions", () => ({ getWebSocketToken: vi.fn(), @@ -72,3 +73,71 @@ describe("getCopilotAuthHeaders", () => { ); }); }); + +// ─── getSendSuppressionReason ───────────────────────────────────────────────── + +function makeUserMsg(text: string): UIMessage { + return { + id: "msg-1", + role: "user", + content: text, + parts: [{ type: "text", text }], + } as UIMessage; +} + +describe("getSendSuppressionReason", () => { + it("returns null when no dedup context exists (fresh ref)", () => { + const result = getSendSuppressionReason({ + text: "hello", + isReconnectScheduled: false, + lastSubmittedText: null, + messages: [], + }); + expect(result).toBeNull(); + }); + + it("returns 'reconnecting' when reconnect is scheduled regardless of text", () => { + const result = getSendSuppressionReason({ + text: "hello", + isReconnectScheduled: true, + lastSubmittedText: null, + messages: [], + }); + expect(result).toBe("reconnecting"); + }); + + it("returns 'duplicate' when same text was submitted and is the last user message", () => { + // This is the core regression test: after a successful turn the ref + // is intentionally NOT cleared to null, so submitting the same text + // again is caught here. + const result = getSendSuppressionReason({ + text: "hello", + isReconnectScheduled: false, + lastSubmittedText: "hello", + messages: [makeUserMsg("hello")], + }); + expect(result).toBe("duplicate"); + }); + + it("returns null when same ref text but different last user message (different question)", () => { + // User asked "hello" before, got a reply, then asked a different question + // — the last user message in chat is now different, so no suppression. 
+ const result = getSendSuppressionReason({ + text: "hello", + isReconnectScheduled: false, + lastSubmittedText: "hello", + messages: [makeUserMsg("hello"), makeUserMsg("something else")], + }); + expect(result).toBeNull(); + }); + + it("returns null when text differs from lastSubmittedText", () => { + const result = getSendSuppressionReason({ + text: "new question", + isReconnectScheduled: false, + lastSubmittedText: "old question", + messages: [makeUserMsg("old question")], + }); + expect(result).toBeNull(); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/store.test.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/store.test.ts index f993daf58d..fd95bbdb2c 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/store.test.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/store.test.ts @@ -1,4 +1,4 @@ -import { describe, expect, it, beforeEach, vi } from "vitest"; +import { describe, expect, it, beforeEach, afterEach, vi } from "vitest"; import { useCopilotUIStore } from "../store"; vi.mock("@sentry/nextjs", () => ({ @@ -22,7 +22,8 @@ describe("useCopilotUIStore", () => { isNotificationsEnabled: false, isSoundEnabled: true, showNotificationDialog: false, - copilotMode: "extended_thinking", + copilotChatMode: "extended_thinking", + copilotLlmModel: "standard", }); }); @@ -154,35 +155,52 @@ describe("useCopilotUIStore", () => { }); }); - describe("copilotMode", () => { + describe("copilotChatMode", () => { it("defaults to extended_thinking", () => { - expect(useCopilotUIStore.getState().copilotMode).toBe( + expect(useCopilotUIStore.getState().copilotChatMode).toBe( "extended_thinking", ); }); it("sets mode to fast", () => { - useCopilotUIStore.getState().setCopilotMode("fast"); - expect(useCopilotUIStore.getState().copilotMode).toBe("fast"); + useCopilotUIStore.getState().setCopilotChatMode("fast"); + expect(useCopilotUIStore.getState().copilotChatMode).toBe("fast"); }); it("sets mode back to extended_thinking", () => { - useCopilotUIStore.getState().setCopilotMode("fast"); - useCopilotUIStore.getState().setCopilotMode("extended_thinking"); - expect(useCopilotUIStore.getState().copilotMode).toBe( + useCopilotUIStore.getState().setCopilotChatMode("fast"); + useCopilotUIStore.getState().setCopilotChatMode("extended_thinking"); + expect(useCopilotUIStore.getState().copilotChatMode).toBe( "extended_thinking", ); }); - it("does not persist mode to localStorage", () => { - useCopilotUIStore.getState().setCopilotMode("fast"); - expect(window.localStorage.getItem("copilot-mode")).toBeNull(); + it("persists mode to localStorage", () => { + useCopilotUIStore.getState().setCopilotChatMode("fast"); + expect(window.localStorage.getItem("copilot-mode")).toBe("fast"); + }); + }); + + describe("copilotLlmModel", () => { + it("defaults to standard", () => { + expect(useCopilotUIStore.getState().copilotLlmModel).toBe("standard"); + }); + + it("sets model to advanced", () => { + useCopilotUIStore.getState().setCopilotLlmModel("advanced"); + expect(useCopilotUIStore.getState().copilotLlmModel).toBe("advanced"); + }); + + it("persists model to localStorage", () => { + useCopilotUIStore.getState().setCopilotLlmModel("advanced"); + expect(window.localStorage.getItem("copilot-model")).toBe("advanced"); }); }); describe("clearCopilotLocalData", () => { it("resets state and clears localStorage keys", () => { - useCopilotUIStore.getState().setCopilotMode("fast"); + 
useCopilotUIStore.getState().setCopilotChatMode("fast");
+      useCopilotUIStore.getState().setCopilotLlmModel("advanced");
       useCopilotUIStore.getState().setNotificationsEnabled(true);
       useCopilotUIStore.getState().toggleSound();
       useCopilotUIStore.getState().addCompletedSession("s1");
@@ -190,7 +208,8 @@
       useCopilotUIStore.getState().clearCopilotLocalData();
       const state = useCopilotUIStore.getState();
-      expect(state.copilotMode).toBe("extended_thinking");
+      expect(state.copilotChatMode).toBe("extended_thinking");
+      expect(state.copilotLlmModel).toBe("standard");
       expect(state.isNotificationsEnabled).toBe(false);
       expect(state.isSoundEnabled).toBe(true);
       expect(state.completedSessionIDs.size).toBe(0);
@@ -198,6 +217,8 @@
         window.localStorage.getItem("copilot-notifications-enabled"),
       ).toBeNull();
       expect(window.localStorage.getItem("copilot-sound-enabled")).toBeNull();
+      expect(window.localStorage.getItem("copilot-mode")).toBeNull();
+      expect(window.localStorage.getItem("copilot-model")).toBeNull();
       expect(
         window.localStorage.getItem("copilot-completed-sessions"),
       ).toBeNull();
@@ -222,3 +243,24 @@
     });
   });
 });
+
+describe("useCopilotUIStore localStorage initialisation", () => {
+  afterEach(() => {
+    vi.resetModules();
+    window.localStorage.clear();
+  });
+
+  it("reads fast chat mode from localStorage on store creation", async () => {
+    window.localStorage.setItem("copilot-mode", "fast");
+    vi.resetModules();
+    const { useCopilotUIStore: fresh } = await import("../store");
+    expect(fresh.getState().copilotChatMode).toBe("fast");
+  });
+
+  it("reads advanced model from localStorage on store creation", async () => {
+    window.localStorage.setItem("copilot-model", "advanced");
+    vi.resetModules();
+    const { useCopilotUIStore: fresh } = await import("../store");
+    expect(fresh.getState().copilotLlmModel).toBe("advanced");
+  });
+});
diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/ChatInput.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/ChatInput.tsx
index d1e1ca4f9d..b6fedb722e 100644
--- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/ChatInput.tsx
+++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/ChatInput.tsx
@@ -13,6 +13,7 @@ import { ChangeEvent, useEffect, useState } from "react";
 import { AttachmentMenu } from "./components/AttachmentMenu";
 import { DryRunToggleButton } from "./components/DryRunToggleButton";
 import { FileChips } from "./components/FileChips";
+import { ModelToggleButton } from "./components/ModelToggleButton";
 import { ModeToggleButton } from "./components/ModeToggleButton";
 import { RecordingButton } from "./components/RecordingButton";
 import { RecordingIndicator } from "./components/RecordingIndicator";
@@ -50,16 +51,22 @@ export function ChatInput({
   onDroppedFilesConsumed,
   hasSession = false,
 }: Props) {
-  const { copilotMode, setCopilotMode, isDryRun, setIsDryRun } =
-    useCopilotUIStore();
+  const {
+    copilotChatMode,
+    setCopilotChatMode,
+    copilotLlmModel,
+    setCopilotLlmModel,
+    isDryRun,
+    setIsDryRun,
+  } = useCopilotUIStore();
   const showModeToggle = useGetFlag(Flag.CHAT_MODE_OPTION);
   const showDryRunToggle = showModeToggle;
   const [files, setFiles] = useState<File[]>([]);

   function handleToggleMode() {
     const next =
-      copilotMode === "extended_thinking" ? "fast" : "extended_thinking";
-    setCopilotMode(next);
+      copilotChatMode === "extended_thinking" ? "fast" : "extended_thinking";
+    setCopilotChatMode(next);
     toast({
       title:
         next === "fast"
@@ -72,6 +79,21 @@ export function ChatInput({
     });
   }

+  function handleToggleModel() {
+    const next = copilotLlmModel === "advanced" ? "standard" : "advanced";
+    setCopilotLlmModel(next);
+    toast({
+      title:
+        next === "advanced"
+          ? "Switched to Advanced model"
+          : "Switched to Standard model",
+      description:
+        next === "advanced"
+          ? "Using the highest-capability model."
+          : "Using the balanced standard model.",
+    });
+  }
+
   function handleToggleDryRun() {
     const next = !isDryRun;
     setIsDryRun(next);
@@ -198,10 +220,16 @@
           />
           {showModeToggle && !isStreaming && (
-            <ModeToggleButton mode={copilotMode} onToggle={handleToggleMode} />
+            <ModeToggleButton mode={copilotChatMode} onToggle={handleToggleMode} />
           )}
+          {showModeToggle && !isStreaming && (
+            <ModelToggleButton
+              model={copilotLlmModel}
+              onToggle={handleToggleModel}
+            />
+          )}
           {showDryRunToggle && (!hasSession || isDryRun) && (
             <DryRunToggleButton isDryRun={isDryRun} onToggle={handleToggleDryRun} />
           )}
diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/__tests__/ChatInput.test.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/__tests__/ChatInput.test.tsx
--- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/__tests__/ChatInput.test.tsx
+++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/__tests__/ChatInput.test.tsx
@@ -20,13 +20,20 @@
 let mockCopilotMode = "extended_thinking";
-const mockSetCopilotMode = vi.fn((mode: string) => {
+const mockSetCopilotChatMode = vi.fn((mode: string) => {
   mockCopilotMode = mode;
 });
+let mockCopilotLlmModel = "standard";
+const mockSetCopilotLlmModel = vi.fn((model: string) => {
+  mockCopilotLlmModel = model;
+});
+
 vi.mock("@/app/(platform)/copilot/store", () => ({
   useCopilotUIStore: () => ({
-    copilotMode: mockCopilotMode,
-    setCopilotMode: mockSetCopilotMode,
+    copilotChatMode: mockCopilotMode,
+    setCopilotChatMode: mockSetCopilotChatMode,
+    copilotLlmModel: mockCopilotLlmModel,
+    setCopilotLlmModel: mockSetCopilotLlmModel,
     initialPrompt: null,
     setInitialPrompt: vi.fn(),
   }),
 }));
@@ -107,6 +114,7 @@ afterEach(() => {
   cleanup();
   vi.clearAllMocks();
   mockCopilotMode = "extended_thinking";
+  mockCopilotLlmModel = "standard";
 });

 describe("ChatInput mode toggle", () => {
@@ -141,7 +149,7 @@
     mockCopilotMode = "extended_thinking";
     render(<ChatInput />);
     fireEvent.click(screen.getByLabelText(/switch to fast mode/i));
-    expect(mockSetCopilotMode).toHaveBeenCalledWith("fast");
+    expect(mockSetCopilotChatMode).toHaveBeenCalledWith("fast");
   });

   it("toggles from fast to extended_thinking on click", () => {
@@ -149,7 +157,7 @@
     mockCopilotMode = "fast";
     render(<ChatInput />);
     fireEvent.click(screen.getByLabelText(/switch to extended thinking/i));
-    expect(mockSetCopilotMode).toHaveBeenCalledWith("extended_thinking");
+    expect(mockSetCopilotChatMode).toHaveBeenCalledWith("extended_thinking");
   });

   it("hides toggle button when streaming", () => {
@@ -187,3 +195,69 @@
     );
   });
 });
+
+describe("ChatInput model toggle", () => {
+  it("renders model toggle button when flag is enabled", () => {
+    mockFlagValue = true;
+    render(<ChatInput />);
+    expect(screen.getByLabelText(/switch to advanced model/i)).toBeDefined();
+  });
+
+  it("does not render model toggle when flag is disabled", () => {
+    mockFlagValue = false;
+    render(<ChatInput />);
+    expect(
+      screen.queryByLabelText(/switch to (advanced|standard) model/i),
+    ).toBeNull();
+  });
+
+  it("toggles from standard to advanced on click", () => {
+    mockFlagValue = true;
+    mockCopilotLlmModel = "standard";
+    render(<ChatInput />);
+    fireEvent.click(screen.getByLabelText(/switch to advanced model/i));
+    expect(mockSetCopilotLlmModel).toHaveBeenCalledWith("advanced");
+  });
+
+  it("toggles from advanced to standard on click", () => {
+    mockFlagValue = true;
+    mockCopilotLlmModel = "advanced";
+    render(<ChatInput />);
+    fireEvent.click(screen.getByLabelText(/switch to standard model/i));
+    expect(mockSetCopilotLlmModel).toHaveBeenCalledWith("standard");
+  });
+
+  it("hides model toggle when streaming", () => {
+    mockFlagValue = true;
+    render(<ChatInput isStreaming />);
+    expect(
+      screen.queryByLabelText(/switch to (advanced|standard) model/i),
+    ).toBeNull();
+  });
+
+  it("shows a toast when switching to advanced", async () => {
+    const { toast } = await import("@/components/molecules/Toast/use-toast");
+    mockFlagValue = true;
+    mockCopilotLlmModel = "standard";
+    render(<ChatInput />);
+    fireEvent.click(screen.getByLabelText(/switch to advanced model/i));
+    expect(toast).toHaveBeenCalledWith(
+      expect.objectContaining({
+        title: expect.stringMatching(/switched to advanced model/i),
+      }),
+    );
+  });
+
+  it("shows a toast when switching to standard", async () => {
+    const { toast } = await import("@/components/molecules/Toast/use-toast");
+    mockFlagValue = true;
+    mockCopilotLlmModel = "advanced";
+    render(<ChatInput />);
+    fireEvent.click(screen.getByLabelText(/switch to standard model/i));
+    expect(toast).toHaveBeenCalledWith(
+      expect.objectContaining({
+        title: expect.stringMatching(/switched to standard model/i),
+      }),
+    );
+  });
+});
diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/ModelToggleButton.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/ModelToggleButton.tsx
new file mode 100644
index 0000000000..cb3bc25f4f
--- /dev/null
+++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/ModelToggleButton.tsx
@@ -0,0 +1,38 @@
+"use client";
+
+import { cn } from "@/lib/utils";
+import { Cpu } from "@phosphor-icons/react";
+import type { CopilotLlmModel } from "../../../store";
+
+interface Props {
+  model: CopilotLlmModel;
+  onToggle: () => void;
+}
+
+export function ModelToggleButton({ model, onToggle }: Props) {
+  const isAdvanced = model === "advanced";
+  return (
+    <button
+      type="button"
+      onClick={onToggle}
+      aria-pressed={isAdvanced}
+      aria-label={
+        isAdvanced ? "Switch to Standard model" : "Switch to Advanced model"
+      }
+      className={cn(
+        "flex items-center gap-1 rounded-md px-1.5 py-1 text-xs",
+        isAdvanced
+          ? "text-violet-600 hover:text-violet-700"
+          : "text-zinc-500 hover:text-zinc-700",
+      )}
+    >
+      <Cpu size={16} />
+      {isAdvanced && <span>Advanced</span>}
+    </button>
+  );
+}
diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/__tests__/ModelToggleButton.test.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/__tests__/ModelToggleButton.test.tsx
new file mode 100644
index 0000000000..a77cb5b6f4
--- /dev/null
+++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/ChatInput/components/__tests__/ModelToggleButton.test.tsx
@@ -0,0 +1,36 @@
+import { render, screen, fireEvent, cleanup } from "@testing-library/react";
+import { afterEach, describe, expect, it, vi } from "vitest";
+import { ModelToggleButton } from "../ModelToggleButton";
+
+afterEach(cleanup);
+
+describe("ModelToggleButton", () => {
+  it("shows no label when model is standard", () => {
+    render(<ModelToggleButton model="standard" onToggle={() => {}} />);
+    expect(screen.queryByText("Advanced")).toBeNull();
+  });
+
+  it("shows Advanced label when model is advanced", () => {
+    render(<ModelToggleButton model="advanced" onToggle={() => {}} />);
+    expect(screen.getByText("Advanced")).toBeTruthy();
+  });
+
+  it("calls onToggle when clicked", () => {
+    const onToggle = vi.fn();
+    render(<ModelToggleButton model="standard" onToggle={onToggle} />);
+    fireEvent.click(screen.getByRole("button"));
+    expect(onToggle).toHaveBeenCalledTimes(1);
+  });
+
+  it("sets aria-pressed=false for standard", () => {
+    render(<ModelToggleButton model="standard" onToggle={() => {}} />);
+    const btn = screen.getByLabelText("Switch to Advanced model");
+    expect(btn.getAttribute("aria-pressed")).toBe("false");
+  });
+
+  it("sets aria-pressed=true for advanced", () => {
+    render(<ModelToggleButton model="advanced" onToggle={() => {}} />);
+    const btn = screen.getByLabelText("Switch to Standard model");
+    expect(btn.getAttribute("aria-pressed")).toBe("true");
+  });
+});
diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts
index 66c437eb86..34e2bea51a 100644
--- a/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts
+++ b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts
@@ -2,6 +2,8 @@ import { getSystemHeaders } from "@/lib/impersonation";
 import { getWebSocketToken } from "@/lib/supabase/actions";
 import type { UIMessage } from "ai";

+import { deleteV2DisconnectSessionStream } from "@/app/api/__generated__/endpoints/chat/chat";
+
 export const ORIGINAL_TITLE = "AutoGPT";

 /**
@@ -154,7 +156,18 @@ export function shouldSuppressDuplicateSend(
 }

 /**
- * Deduplicate messages by ID and by content fingerprint.
+ * Fire-and-forget: tell the backend to release XREAD listeners for a session.
+ *
+ * Called on session switch so the backend doesn't wait for its 5-10 s timeout
+ * before cleaning up. Failures are silently ignored — the backend will
+ * eventually clean up on its own.
+ */
+export function disconnectSessionStream(sessionId: string): void {
+  deleteV2DisconnectSessionStream(sessionId).catch(() => {});
+}
+
+/**
+ * Deduplicate messages by ID and by consecutive content fingerprint.
  *
  * ID dedup catches exact duplicates within the same source.
  * Content dedup uses a composite key of `role + preceding-user-message-id +
diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/store.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/store.ts
index d63c0bd76a..d8dcbd132c 100644
--- a/autogpt_platform/frontend/src/app/(platform)/copilot/store.ts
+++ b/autogpt_platform/frontend/src/app/(platform)/copilot/store.ts
@@ -53,6 +53,9 @@ export const DEFAULT_PANEL_WIDTH = 600;
 /** Autopilot response mode. */
 export type CopilotMode = "extended_thinking" | "fast";

+/** Per-request model tier. 'standard' = current default; 'advanced' = highest-capability. */
+export type CopilotLlmModel = "standard" | "advanced";
+
 const isClient = typeof window !== "undefined";

 function getPersistedWidth(): number {
@@ -134,8 +137,12 @@ interface CopilotUIState {
   goBackArtifact: () => void;

   /** Autopilot mode: 'extended_thinking' (default) or 'fast'. */
-  copilotMode: CopilotMode;
-  setCopilotMode: (mode: CopilotMode) => void;
+  copilotChatMode: CopilotMode;
+  setCopilotChatMode: (mode: CopilotMode) => void;
+
+  /** Model tier: 'standard' (default) or 'advanced' (highest-capability). */
+  copilotLlmModel: CopilotLlmModel;
+  setCopilotLlmModel: (model: CopilotLlmModel) => void;

   /** Developer dry-run mode: sessions created with dry_run=true. */
   isDryRun: boolean;
@@ -298,9 +305,22 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
     };
   }),

-  copilotMode: "extended_thinking",
-  setCopilotMode: (mode) => {
-    set({ copilotMode: mode });
+  copilotChatMode: (() => {
+    const saved = isClient ? storage.get(Key.COPILOT_MODE) : null;
+    return saved === "fast" ? "fast" : "extended_thinking";
+  })(),
+  setCopilotChatMode: (mode) => {
+    storage.set(Key.COPILOT_MODE, mode);
+    set({ copilotChatMode: mode });
+  },
+
+  copilotLlmModel: (() => {
+    const saved = isClient ? storage.get(Key.COPILOT_MODEL) : null;
+    return saved === "advanced" ? "advanced" : "standard";
+  })(),
+  setCopilotLlmModel: (model) => {
+    storage.set(Key.COPILOT_MODEL, model);
+    set({ copilotLlmModel: model });
   },

   isDryRun: isClient && storage.get(Key.COPILOT_DRY_RUN) === "true",
@@ -322,6 +342,8 @@
     storage.clean(Key.COPILOT_ARTIFACT_PANEL_WIDTH);
     storage.clean(Key.COPILOT_COMPLETED_SESSIONS);
     storage.clean(Key.COPILOT_DRY_RUN);
+    storage.clean(Key.COPILOT_MODE);
+    storage.clean(Key.COPILOT_MODEL);
     set({
       completedSessionIDs: new Set(),
       isNotificationsEnabled: false,
@@ -334,7 +356,8 @@
         activeArtifact: null,
         history: [],
       },
-      copilotMode: "extended_thinking",
+      copilotChatMode: "extended_thinking",
+      copilotLlmModel: "standard",
       isDryRun: false,
     });
     if (isClient) {
diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts
index f8b0387c6b..01302c9f81 100644
--- a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts
+++ b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts
@@ -42,7 +42,8 @@ export function useCopilotPage() {
     setSessionToDelete,
     isDrawerOpen,
     setDrawerOpen,
-    copilotMode,
+    copilotChatMode,
+    copilotLlmModel,
     isDryRun,
   } = useCopilotUIStore();

@@ -78,7 +79,8 @@ export function useCopilotPage() {
     hydratedMessages,
     hasActiveStream,
     refetchSession,
-    copilotMode: isModeToggleEnabled ? copilotMode : undefined,
+    copilotMode: isModeToggleEnabled ? copilotChatMode : undefined,
+    copilotModel: isModeToggleEnabled ? copilotLlmModel : undefined,
   });

   const { olderMessages, hasMore, isLoadingMore, loadMore } =
diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts
index 918047d3d8..666b87bfba 100644
--- a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts
+++ b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts
@@ -17,8 +17,9 @@ import {
   hasActiveBackendStream,
   resolveInProgressTools,
   getSendSuppressionReason,
+  disconnectSessionStream,
 } from "./helpers";
-import type { CopilotMode } from "./store";
+import type { CopilotLlmModel, CopilotMode } from "./store";

 const RECONNECT_BASE_DELAY_MS = 1_000;
 const RECONNECT_MAX_ATTEMPTS = 3;
@@ -33,6 +34,8 @@ interface UseCopilotStreamArgs {
   refetchSession: () => Promise<{ data?: unknown }>;
   /** Autopilot mode to use for requests. `undefined` = let backend decide via feature flags. */
   copilotMode: CopilotMode | undefined;
+  /** Model tier override. `undefined` = let backend decide. */
+  copilotModel: CopilotLlmModel | undefined;
 }

 export function useCopilotStream({
@@ -41,17 +44,20 @@
   hasActiveStream,
   refetchSession,
   copilotMode,
+  copilotModel,
 }: UseCopilotStreamArgs) {
   const queryClient = useQueryClient();
   const [rateLimitMessage, setRateLimitMessage] = useState<string | null>(null);
   function dismissRateLimit() {
     setRateLimitMessage(null);
   }

-  // Use a ref for copilotMode so the transport closure always reads the
-  // latest value without recreating the DefaultChatTransport (which would
+  // Use refs for copilotMode and copilotModel so the transport closure always reads
+  // the latest value without recreating the DefaultChatTransport (which would
   // reset useChat's internal Chat instance and break mid-session streaming).
   const copilotModeRef = useRef(copilotMode);
   copilotModeRef.current = copilotMode;
+  const copilotModelRef = useRef(copilotModel);
+  copilotModelRef.current = copilotModel;

   // Connect directly to the Python backend for SSE, bypassing the Next.js
   // serverless proxy. This eliminates the Vercel 800s function timeout that
@@ -83,6 +89,7 @@
             context: null,
             file_ids: fileIds && fileIds.length > 0 ? fileIds : null,
             mode: copilotModeRef.current ?? null,
+            model: copilotModelRef.current ?? null,
           },
           headers: await getCopilotAuthHeaders(),
         };
@@ -147,16 +154,15 @@
     reconnectTimerRef.current = setTimeout(() => {
       isReconnectScheduledRef.current = false;
       setIsReconnectScheduled(false);
-      // Strip any stale in-progress assistant message before resuming.
-      // The backend replays from "0-0", so the partial message would
-      // otherwise sit alongside the fully-replayed version.
+      // Strip the stale in-progress assistant message before resuming —
+      // the backend replays from "0-0", so keeping it would duplicate parts.
       setMessages((prev) => {
         if (prev.length > 0 && prev[prev.length - 1].role === "assistant") {
           return prev.slice(0, -1);
         }
         return prev;
       });
-      resumeStream();
+      resumeStreamRef.current();
     }, delay);
   }
@@ -254,6 +260,14 @@
     },
   });

+  // Keep stable refs to sdkStop and resumeStream so that async callbacks
+  // (session-switch cleanup, wake re-sync, reconnect timer) always call the
+  // latest version without stale-closure bugs.
+  const sdkStopRef = useRef(sdkStop);
+  sdkStopRef.current = sdkStop;
+  const resumeStreamRef = useRef(resumeStream);
+  resumeStreamRef.current = resumeStream;
+
   // Wrap sdkSendMessage to guard against re-sending the user message during a
   // reconnect cycle. If the session already has the message (i.e. we are in a
   // reconnect/resume flow), only GET-resume is safe — never re-POST.
@@ -380,7 +394,7 @@
         }
         return prev;
       });
-      await resumeStream();
+      await resumeStreamRef.current();
     }
     // If !backendActive, the refetch will update hydratedMessages via
     // React Query, and the hydration effect below will merge them in.
@@ -403,7 +417,7 @@
     return () => {
       document.removeEventListener("visibilitychange", onVisibilityChange);
     };
-  }, [refetchSession, setMessages, resumeStream]);
+  }, [refetchSession, setMessages]);

   // Hydrate messages from REST API when not actively streaming
   useEffect(() => {
@@ -419,8 +433,34 @@
   // Track resume state per session
   const hasResumedRef = useRef<Map<string, boolean>>(new Map());

-  // Clean up reconnect state on session switch
+  // Clean up reconnect state on session switch.
+  // Abort the old stream's in-flight fetch and tell the backend to release
+  // its XREAD listeners immediately (fire-and-forget).
+  const prevStreamSessionRef = useRef(sessionId);
   useEffect(() => {
+    const prevSid = prevStreamSessionRef.current;
+    prevStreamSessionRef.current = sessionId;
+
+    const isSwitching = Boolean(prevSid && prevSid !== sessionId);
+    if (isSwitching) {
+      // Mark BEFORE stopping so the old stream's async onError (which fires
+      // after the abort) sees the flag and short-circuits the reconnect path.
+      // Without this, the AbortError can queue a reconnect against the new
+      // session's `sessionId` (captured in the fresh onError closure).
+ isUserStoppingRef.current = true; + sdkStopRef.current(); + disconnectSessionStream(prevSid!); + // Schedule the reset as a task (not a microtask) so it runs AFTER the + // aborted fetch's onError has fired — otherwise the new session would + // be stuck with the "user stopping" flag set, preventing auto-resume + // when hydration detects an active backend stream. + setTimeout(() => { + isUserStoppingRef.current = false; + }, 0); + } else { + isUserStoppingRef.current = false; + } + clearTimeout(reconnectTimerRef.current); reconnectTimerRef.current = undefined; reconnectAttemptsRef.current = 0; @@ -428,7 +468,6 @@ export function useCopilotStream({ setIsReconnectScheduled(false); setRateLimitMessage(null); hasShownDisconnectToast.current = false; - isUserStoppingRef.current = false; lastSubmittedMsgRef.current = null; setReconnectExhausted(false); setIsSyncing(false); @@ -458,7 +497,12 @@ export function useCopilotStream({ if (status === "ready") { reconnectAttemptsRef.current = 0; hasShownDisconnectToast.current = false; - lastSubmittedMsgRef.current = null; + // Intentionally NOT clearing lastSubmittedMsgRef here: keeping the last + // submitted text prevents getSendSuppressionReason from allowing a + // duplicate POST of the same message immediately after a successful turn + // (the "duplicate" branch checks both the ref and the visible last user + // message, so legitimate re-sends after a different reply are still + // allowed). setReconnectExhausted(false); } } @@ -495,15 +539,8 @@ export function useCopilotStream({ return prev; }); - resumeStream(); - }, [ - sessionId, - hasActiveStream, - hydratedMessages, - status, - resumeStream, - setMessages, - ]); + resumeStreamRef.current(); + }, [sessionId, hasActiveStream, hydratedMessages, status, setMessages]); // Clear messages when session is null useEffect(() => { diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index 732ef569d9..f93caabbb1 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -1606,6 +1606,35 @@ } }, "/api/chat/sessions/{session_id}/stream": { + "delete": { + "tags": ["v2", "chat", "chat"], + "summary": "Disconnect Session Stream", + "description": "Disconnect all active SSE listeners for a session.\n\nCalled by the frontend when the user switches away from a chat so the\nbackend releases XREAD listeners immediately rather than waiting for\nthe 5-10 s timeout.", + "operationId": "deleteV2DisconnectSessionStream", + "security": [{ "HTTPBearerJWT": [] }], + "parameters": [ + { + "name": "session_id", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Session Id" } + } + ], + "responses": { + "204": { "description": "Successful Response" }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + }, "get": { "tags": ["v2", "chat", "chat"], "summary": "Resume Session Stream", @@ -13931,6 +13960,14 @@ ], "title": "Mode", "description": "Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. If None, uses the server default (extended_thinking)." 
+ }, + "model": { + "anyOf": [ + { "type": "string", "enum": ["standard", "advanced"] }, + { "type": "null" } + ], + "title": "Model", + "description": "Model tier: 'standard' for the default model, 'advanced' for the highest-capability model. If None, the server applies per-user LD targeting then falls back to config." } }, "type": "object", diff --git a/autogpt_platform/frontend/src/services/storage/local-storage.ts b/autogpt_platform/frontend/src/services/storage/local-storage.ts index de31967d53..b5c0392ecd 100644 --- a/autogpt_platform/frontend/src/services/storage/local-storage.ts +++ b/autogpt_platform/frontend/src/services/storage/local-storage.ts @@ -17,6 +17,7 @@ export enum Key { COPILOT_NOTIFICATION_DIALOG_DISMISSED = "copilot-notification-dialog-dismissed", COPILOT_ARTIFACT_PANEL_WIDTH = "copilot-artifact-panel-width", COPILOT_MODE = "copilot-mode", + COPILOT_MODEL = "copilot-model", COPILOT_COMPLETED_SESSIONS = "copilot-completed-sessions", COPILOT_DRY_RUN = "copilot-dry-run", }
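For context, the new DELETE route documented above gives clients an explicit way to release a session's SSE listeners instead of waiting out the backend's XREAD timeout. A minimal client sketch against that OpenAPI entry (httpx-based; the base URL and bearer-token plumbing are assumptions for illustration, not part of this change):

    # Sketch: call DELETE /api/chat/sessions/{session_id}/stream, i.e. the
    # deleteV2DisconnectSessionStream operation added above. Expects 204 on
    # success; 401 without a valid bearer JWT (per the documented responses).
    import httpx

    async def disconnect_session_stream(
        base_url: str, token: str, session_id: str
    ) -> None:
        async with httpx.AsyncClient(base_url=base_url) as client:
            resp = await client.delete(
                f"/api/chat/sessions/{session_id}/stream",
                headers={"Authorization": f"Bearer {token}"},
            )
            resp.raise_for_status()  # 204 expected; any 4xx/5xx raises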