diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index f8c3e3b804..ac7325e201 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -381,6 +381,31 @@ async def delete_session( return Response(status_code=204) +@router.delete( + "/sessions/{session_id}/stream", + dependencies=[Security(auth.requires_user)], + status_code=204, +) +async def disconnect_session_stream( + session_id: str, + user_id: Annotated[str, Security(auth.get_user_id)], +) -> Response: + """Disconnect all active SSE listeners for a session. + + Called by the frontend when the user switches away from a chat so the + backend releases XREAD listeners immediately rather than waiting for + the 5-10 s timeout. + """ + session = await get_chat_session(session_id, user_id) + if not session: + raise HTTPException( + status_code=404, + detail=f"Session {session_id} not found or access denied", + ) + await stream_registry.disconnect_all_listeners(session_id) + return Response(status_code=204) + + @router.patch( "/sessions/{session_id}/title", summary="Update session title", diff --git a/autogpt_platform/backend/backend/api/features/chat/routes_test.py b/autogpt_platform/backend/backend/api/features/chat/routes_test.py index f3896c7098..74259b3463 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes_test.py @@ -677,3 +677,48 @@ class TestStripInjectedContext: result = _strip_injected_context(msg) # Without a role, the helper short-circuits without touching content. assert result["content"] == "hello" + + +# ─── DELETE /sessions/{id}/stream — disconnect listeners ────────────── + + +def test_disconnect_stream_returns_204_and_awaits_registry( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + mock_session = MagicMock() + mocker.patch( + "backend.api.features.chat.routes.get_chat_session", + new_callable=AsyncMock, + return_value=mock_session, + ) + mock_disconnect = mocker.patch( + "backend.api.features.chat.routes.stream_registry.disconnect_all_listeners", + new_callable=AsyncMock, + return_value=2, + ) + + response = client.delete("/sessions/sess-1/stream") + + assert response.status_code == 204 + mock_disconnect.assert_awaited_once_with("sess-1") + + +def test_disconnect_stream_returns_404_when_session_missing( + mocker: pytest_mock.MockerFixture, + test_user_id: str, +) -> None: + mocker.patch( + "backend.api.features.chat.routes.get_chat_session", + new_callable=AsyncMock, + return_value=None, + ) + mock_disconnect = mocker.patch( + "backend.api.features.chat.routes.stream_registry.disconnect_all_listeners", + new_callable=AsyncMock, + ) + + response = client.delete("/sessions/unknown-session/stream") + + assert response.status_code == 404 + mock_disconnect.assert_not_awaited() diff --git a/autogpt_platform/backend/backend/copilot/stream_registry.py b/autogpt_platform/backend/backend/copilot/stream_registry.py index 163b8c1bab..030763dbca 100644 --- a/autogpt_platform/backend/backend/copilot/stream_registry.py +++ b/autogpt_platform/backend/backend/copilot/stream_registry.py @@ -1149,3 +1149,50 @@ async def unsubscribe_from_session( ) logger.debug(f"Successfully unsubscribed from session {session_id}") + + +async def disconnect_all_listeners(session_id: str) -> int: + """Cancel every active listener task for *session_id*. + + Called when the frontend switches away from a session and wants the + backend to release resources immediately rather than waiting for the + XREAD timeout. + + Scope / limitations (best-effort optimisation, not a correctness primitive): + - Pod-local: ``_listener_sessions`` is in-memory. If the DELETE request + lands on a different worker than the one serving the SSE, no listener + is cancelled here — the SSE worker still releases on its XREAD timeout. + - Session-scoped (not subscriber-scoped): cancels every active listener + for the session on this pod. In the rare case a single user opens two + SSE connections to the same session on the same pod (e.g. two tabs), + both would be torn down. Cross-pod, subscriber-scoped cancellation + would require a Redis pub/sub fan-out with per-listener tokens; that + is not implemented here because the XREAD timeout already bounds the + worst case. + + Returns the number of listener tasks that were cancelled. + """ + to_cancel: list[tuple[int, asyncio.Task]] = [ + (qid, task) + for qid, (sid, task) in list(_listener_sessions.items()) + if sid == session_id and not task.done() + ] + + for qid, task in to_cancel: + _listener_sessions.pop(qid, None) + task.cancel() + + cancelled = 0 + for _qid, task in to_cancel: + try: + await asyncio.wait_for(task, timeout=5.0) + except asyncio.CancelledError: + cancelled += 1 + except asyncio.TimeoutError: + pass + except Exception as e: + logger.error(f"Error cancelling listener for session {session_id}: {e}") + + if cancelled: + logger.info(f"Disconnected {cancelled} listener(s) for session {session_id}") + return cancelled diff --git a/autogpt_platform/backend/backend/copilot/stream_registry_test.py b/autogpt_platform/backend/backend/copilot/stream_registry_test.py new file mode 100644 index 0000000000..a09940a4a8 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/stream_registry_test.py @@ -0,0 +1,110 @@ +"""Tests for disconnect_all_listeners in stream_registry.""" + +import asyncio +from unittest.mock import AsyncMock, patch + +import pytest + +from backend.copilot import stream_registry + + +@pytest.fixture(autouse=True) +def _clear_listener_sessions(): + stream_registry._listener_sessions.clear() + yield + stream_registry._listener_sessions.clear() + + +async def _sleep_forever(): + try: + await asyncio.sleep(3600) + except asyncio.CancelledError: + raise + + +@pytest.mark.asyncio +async def test_disconnect_all_listeners_cancels_matching_session(): + task_a = asyncio.create_task(_sleep_forever()) + task_b = asyncio.create_task(_sleep_forever()) + task_other = asyncio.create_task(_sleep_forever()) + + stream_registry._listener_sessions[1] = ("sess-1", task_a) + stream_registry._listener_sessions[2] = ("sess-1", task_b) + stream_registry._listener_sessions[3] = ("sess-other", task_other) + + try: + cancelled = await stream_registry.disconnect_all_listeners("sess-1") + + assert cancelled == 2 + assert task_a.cancelled() + assert task_b.cancelled() + assert not task_other.done() + # Matching entries are removed, non-matching entries remain. + assert 1 not in stream_registry._listener_sessions + assert 2 not in stream_registry._listener_sessions + assert 3 in stream_registry._listener_sessions + finally: + task_other.cancel() + try: + await task_other + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_disconnect_all_listeners_no_match_returns_zero(): + task = asyncio.create_task(_sleep_forever()) + stream_registry._listener_sessions[1] = ("sess-other", task) + + try: + cancelled = await stream_registry.disconnect_all_listeners("sess-missing") + + assert cancelled == 0 + assert not task.done() + assert 1 in stream_registry._listener_sessions + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_disconnect_all_listeners_skips_already_done_tasks(): + async def _noop(): + return None + + done_task = asyncio.create_task(_noop()) + await done_task + stream_registry._listener_sessions[1] = ("sess-1", done_task) + + cancelled = await stream_registry.disconnect_all_listeners("sess-1") + + # Done tasks are filtered out before cancellation. + assert cancelled == 0 + + +@pytest.mark.asyncio +async def test_disconnect_all_listeners_empty_registry(): + cancelled = await stream_registry.disconnect_all_listeners("sess-1") + assert cancelled == 0 + + +@pytest.mark.asyncio +async def test_disconnect_all_listeners_timeout_not_counted(): + """Tasks that don't respond to cancellation (timeout) are not counted.""" + task = asyncio.create_task(_sleep_forever()) + stream_registry._listener_sessions[1] = ("sess-1", task) + + with patch.object( + asyncio, "wait_for", new=AsyncMock(side_effect=asyncio.TimeoutError) + ): + cancelled = await stream_registry.disconnect_all_listeners("sess-1") + + assert cancelled == 0 + task.cancel() + try: + await task + except asyncio.CancelledError: + pass diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts index 66c437eb86..34e2bea51a 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts @@ -2,6 +2,8 @@ import { getSystemHeaders } from "@/lib/impersonation"; import { getWebSocketToken } from "@/lib/supabase/actions"; import type { UIMessage } from "ai"; +import { deleteV2DisconnectSessionStream } from "@/app/api/__generated__/endpoints/chat/chat"; + export const ORIGINAL_TITLE = "AutoGPT"; /** @@ -154,7 +156,18 @@ export function shouldSuppressDuplicateSend( } /** - * Deduplicate messages by ID and by content fingerprint. + * Fire-and-forget: tell the backend to release XREAD listeners for a session. + * + * Called on session switch so the backend doesn't wait for its 5-10 s timeout + * before cleaning up. Failures are silently ignored — the backend will + * eventually clean up on its own. + */ +export function disconnectSessionStream(sessionId: string): void { + deleteV2DisconnectSessionStream(sessionId).catch(() => {}); +} + +/** + * Deduplicate messages by ID and by consecutive content fingerprint. * * ID dedup catches exact duplicates within the same source. * Content dedup uses a composite key of `role + preceding-user-message-id + diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts index 14ea672bfb..85709f23d9 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts @@ -17,6 +17,7 @@ import { hasActiveBackendStream, resolveInProgressTools, getSendSuppressionReason, + disconnectSessionStream, } from "./helpers"; import type { CopilotLlmModel, CopilotMode } from "./store"; @@ -153,16 +154,15 @@ export function useCopilotStream({ reconnectTimerRef.current = setTimeout(() => { isReconnectScheduledRef.current = false; setIsReconnectScheduled(false); - // Strip any stale in-progress assistant message before resuming. - // The backend replays from "0-0", so the partial message would - // otherwise sit alongside the fully-replayed version. + // Strip the stale in-progress assistant message before resuming — + // the backend replays from "0-0", so keeping it would duplicate parts. setMessages((prev) => { if (prev.length > 0 && prev[prev.length - 1].role === "assistant") { return prev.slice(0, -1); } return prev; }); - resumeStream(); + resumeStreamRef.current(); }, delay); } @@ -260,6 +260,14 @@ export function useCopilotStream({ }, }); + // Keep stable refs to sdkStop and resumeStream so that async callbacks + // (session-switch cleanup, wake re-sync, reconnect timer) always call the + // latest version without stale-closure bugs. + const sdkStopRef = useRef(sdkStop); + sdkStopRef.current = sdkStop; + const resumeStreamRef = useRef(resumeStream); + resumeStreamRef.current = resumeStream; + // Wrap sdkSendMessage to guard against re-sending the user message during a // reconnect cycle. If the session already has the message (i.e. we are in a // reconnect/resume flow), only GET-resume is safe — never re-POST. @@ -386,7 +394,7 @@ export function useCopilotStream({ } return prev; }); - await resumeStream(); + await resumeStreamRef.current(); } // If !backendActive, the refetch will update hydratedMessages via // React Query, and the hydration effect below will merge them in. @@ -409,7 +417,7 @@ export function useCopilotStream({ return () => { document.removeEventListener("visibilitychange", onVisibilityChange); }; - }, [refetchSession, setMessages, resumeStream]); + }, [refetchSession, setMessages]); // Hydrate messages from REST API when not actively streaming useEffect(() => { @@ -425,8 +433,34 @@ export function useCopilotStream({ // Track resume state per session const hasResumedRef = useRef>(new Map()); - // Clean up reconnect state on session switch + // Clean up reconnect state on session switch. + // Abort the old stream's in-flight fetch and tell the backend to release + // its XREAD listeners immediately (fire-and-forget). + const prevStreamSessionRef = useRef(sessionId); useEffect(() => { + const prevSid = prevStreamSessionRef.current; + prevStreamSessionRef.current = sessionId; + + const isSwitching = Boolean(prevSid && prevSid !== sessionId); + if (isSwitching) { + // Mark BEFORE stopping so the old stream's async onError (which fires + // after the abort) sees the flag and short-circuits the reconnect path. + // Without this, the AbortError can queue a reconnect against the new + // session's `sessionId` (captured in the fresh onError closure). + isUserStoppingRef.current = true; + sdkStopRef.current(); + disconnectSessionStream(prevSid!); + // Schedule the reset as a task (not a microtask) so it runs AFTER the + // aborted fetch's onError has fired — otherwise the new session would + // be stuck with the "user stopping" flag set, preventing auto-resume + // when hydration detects an active backend stream. + setTimeout(() => { + isUserStoppingRef.current = false; + }, 0); + } else { + isUserStoppingRef.current = false; + } + clearTimeout(reconnectTimerRef.current); reconnectTimerRef.current = undefined; reconnectAttemptsRef.current = 0; @@ -434,7 +468,6 @@ export function useCopilotStream({ setIsReconnectScheduled(false); setRateLimitMessage(null); hasShownDisconnectToast.current = false; - isUserStoppingRef.current = false; lastSubmittedMsgRef.current = null; setReconnectExhausted(false); setIsSyncing(false); @@ -501,15 +534,8 @@ export function useCopilotStream({ return prev; }); - resumeStream(); - }, [ - sessionId, - hasActiveStream, - hydratedMessages, - status, - resumeStream, - setMessages, - ]); + resumeStreamRef.current(); + }, [sessionId, hasActiveStream, hydratedMessages, status, setMessages]); // Clear messages when session is null useEffect(() => { diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index e1f3120f9f..e5ad3bf296 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -1633,6 +1633,35 @@ } }, "/api/chat/sessions/{session_id}/stream": { + "delete": { + "tags": ["v2", "chat", "chat"], + "summary": "Disconnect Session Stream", + "description": "Disconnect all active SSE listeners for a session.\n\nCalled by the frontend when the user switches away from a chat so the\nbackend releases XREAD listeners immediately rather than waiting for\nthe 5-10 s timeout.", + "operationId": "deleteV2DisconnectSessionStream", + "security": [{ "HTTPBearerJWT": [] }], + "parameters": [ + { + "name": "session_id", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Session Id" } + } + ], + "responses": { + "204": { "description": "Successful Response" }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + }, "get": { "tags": ["v2", "chat", "chat"], "summary": "Resume Session Stream",