mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into fix/copilot-model-toggle-styling
This commit is contained in:
@@ -381,6 +381,31 @@ async def delete_session(
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/sessions/{session_id}/stream",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
status_code=204,
|
||||
)
|
||||
async def disconnect_session_stream(
|
||||
session_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> Response:
|
||||
"""Disconnect all active SSE listeners for a session.
|
||||
|
||||
Called by the frontend when the user switches away from a chat so the
|
||||
backend releases XREAD listeners immediately rather than waiting for
|
||||
the 5-10 s timeout.
|
||||
"""
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Session {session_id} not found or access denied",
|
||||
)
|
||||
await stream_registry.disconnect_all_listeners(session_id)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/sessions/{session_id}/title",
|
||||
summary="Update session title",
|
||||
|
||||
@@ -677,3 +677,48 @@ class TestStripInjectedContext:
|
||||
result = _strip_injected_context(msg)
|
||||
# Without a role, the helper short-circuits without touching content.
|
||||
assert result["content"] == "hello"
|
||||
|
||||
|
||||
# ─── DELETE /sessions/{id}/stream — disconnect listeners ──────────────
|
||||
|
||||
|
||||
def test_disconnect_stream_returns_204_and_awaits_registry(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mock_session = MagicMock()
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_chat_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_session,
|
||||
)
|
||||
mock_disconnect = mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry.disconnect_all_listeners",
|
||||
new_callable=AsyncMock,
|
||||
return_value=2,
|
||||
)
|
||||
|
||||
response = client.delete("/sessions/sess-1/stream")
|
||||
|
||||
assert response.status_code == 204
|
||||
mock_disconnect.assert_awaited_once_with("sess-1")
|
||||
|
||||
|
||||
def test_disconnect_stream_returns_404_when_session_missing(
|
||||
mocker: pytest_mock.MockerFixture,
|
||||
test_user_id: str,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"backend.api.features.chat.routes.get_chat_session",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
)
|
||||
mock_disconnect = mocker.patch(
|
||||
"backend.api.features.chat.routes.stream_registry.disconnect_all_listeners",
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
|
||||
response = client.delete("/sessions/unknown-session/stream")
|
||||
|
||||
assert response.status_code == 404
|
||||
mock_disconnect.assert_not_awaited()
|
||||
|
||||
@@ -1149,3 +1149,50 @@ async def unsubscribe_from_session(
|
||||
)
|
||||
|
||||
logger.debug(f"Successfully unsubscribed from session {session_id}")
|
||||
|
||||
|
||||
async def disconnect_all_listeners(session_id: str) -> int:
|
||||
"""Cancel every active listener task for *session_id*.
|
||||
|
||||
Called when the frontend switches away from a session and wants the
|
||||
backend to release resources immediately rather than waiting for the
|
||||
XREAD timeout.
|
||||
|
||||
Scope / limitations (best-effort optimisation, not a correctness primitive):
|
||||
- Pod-local: ``_listener_sessions`` is in-memory. If the DELETE request
|
||||
lands on a different worker than the one serving the SSE, no listener
|
||||
is cancelled here — the SSE worker still releases on its XREAD timeout.
|
||||
- Session-scoped (not subscriber-scoped): cancels every active listener
|
||||
for the session on this pod. In the rare case a single user opens two
|
||||
SSE connections to the same session on the same pod (e.g. two tabs),
|
||||
both would be torn down. Cross-pod, subscriber-scoped cancellation
|
||||
would require a Redis pub/sub fan-out with per-listener tokens; that
|
||||
is not implemented here because the XREAD timeout already bounds the
|
||||
worst case.
|
||||
|
||||
Returns the number of listener tasks that were cancelled.
|
||||
"""
|
||||
to_cancel: list[tuple[int, asyncio.Task]] = [
|
||||
(qid, task)
|
||||
for qid, (sid, task) in list(_listener_sessions.items())
|
||||
if sid == session_id and not task.done()
|
||||
]
|
||||
|
||||
for qid, task in to_cancel:
|
||||
_listener_sessions.pop(qid, None)
|
||||
task.cancel()
|
||||
|
||||
cancelled = 0
|
||||
for _qid, task in to_cancel:
|
||||
try:
|
||||
await asyncio.wait_for(task, timeout=5.0)
|
||||
except asyncio.CancelledError:
|
||||
cancelled += 1
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Error cancelling listener for session {session_id}: {e}")
|
||||
|
||||
if cancelled:
|
||||
logger.info(f"Disconnected {cancelled} listener(s) for session {session_id}")
|
||||
return cancelled
|
||||
|
||||
110
autogpt_platform/backend/backend/copilot/stream_registry_test.py
Normal file
110
autogpt_platform/backend/backend/copilot/stream_registry_test.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""Tests for disconnect_all_listeners in stream_registry."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot import stream_registry
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_listener_sessions():
|
||||
stream_registry._listener_sessions.clear()
|
||||
yield
|
||||
stream_registry._listener_sessions.clear()
|
||||
|
||||
|
||||
async def _sleep_forever():
|
||||
try:
|
||||
await asyncio.sleep(3600)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_all_listeners_cancels_matching_session():
|
||||
task_a = asyncio.create_task(_sleep_forever())
|
||||
task_b = asyncio.create_task(_sleep_forever())
|
||||
task_other = asyncio.create_task(_sleep_forever())
|
||||
|
||||
stream_registry._listener_sessions[1] = ("sess-1", task_a)
|
||||
stream_registry._listener_sessions[2] = ("sess-1", task_b)
|
||||
stream_registry._listener_sessions[3] = ("sess-other", task_other)
|
||||
|
||||
try:
|
||||
cancelled = await stream_registry.disconnect_all_listeners("sess-1")
|
||||
|
||||
assert cancelled == 2
|
||||
assert task_a.cancelled()
|
||||
assert task_b.cancelled()
|
||||
assert not task_other.done()
|
||||
# Matching entries are removed, non-matching entries remain.
|
||||
assert 1 not in stream_registry._listener_sessions
|
||||
assert 2 not in stream_registry._listener_sessions
|
||||
assert 3 in stream_registry._listener_sessions
|
||||
finally:
|
||||
task_other.cancel()
|
||||
try:
|
||||
await task_other
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_all_listeners_no_match_returns_zero():
|
||||
task = asyncio.create_task(_sleep_forever())
|
||||
stream_registry._listener_sessions[1] = ("sess-other", task)
|
||||
|
||||
try:
|
||||
cancelled = await stream_registry.disconnect_all_listeners("sess-missing")
|
||||
|
||||
assert cancelled == 0
|
||||
assert not task.done()
|
||||
assert 1 in stream_registry._listener_sessions
|
||||
finally:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_all_listeners_skips_already_done_tasks():
|
||||
async def _noop():
|
||||
return None
|
||||
|
||||
done_task = asyncio.create_task(_noop())
|
||||
await done_task
|
||||
stream_registry._listener_sessions[1] = ("sess-1", done_task)
|
||||
|
||||
cancelled = await stream_registry.disconnect_all_listeners("sess-1")
|
||||
|
||||
# Done tasks are filtered out before cancellation.
|
||||
assert cancelled == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_all_listeners_empty_registry():
|
||||
cancelled = await stream_registry.disconnect_all_listeners("sess-1")
|
||||
assert cancelled == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_all_listeners_timeout_not_counted():
|
||||
"""Tasks that don't respond to cancellation (timeout) are not counted."""
|
||||
task = asyncio.create_task(_sleep_forever())
|
||||
stream_registry._listener_sessions[1] = ("sess-1", task)
|
||||
|
||||
with patch.object(
|
||||
asyncio, "wait_for", new=AsyncMock(side_effect=asyncio.TimeoutError)
|
||||
):
|
||||
cancelled = await stream_registry.disconnect_all_listeners("sess-1")
|
||||
|
||||
assert cancelled == 0
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
@@ -2,6 +2,8 @@ import { getSystemHeaders } from "@/lib/impersonation";
|
||||
import { getWebSocketToken } from "@/lib/supabase/actions";
|
||||
import type { UIMessage } from "ai";
|
||||
|
||||
import { deleteV2DisconnectSessionStream } from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
|
||||
export const ORIGINAL_TITLE = "AutoGPT";
|
||||
|
||||
/**
|
||||
@@ -154,7 +156,18 @@ export function shouldSuppressDuplicateSend(
|
||||
}
|
||||
|
||||
/**
|
||||
* Deduplicate messages by ID and by content fingerprint.
|
||||
* Fire-and-forget: tell the backend to release XREAD listeners for a session.
|
||||
*
|
||||
* Called on session switch so the backend doesn't wait for its 5-10 s timeout
|
||||
* before cleaning up. Failures are silently ignored — the backend will
|
||||
* eventually clean up on its own.
|
||||
*/
|
||||
export function disconnectSessionStream(sessionId: string): void {
|
||||
deleteV2DisconnectSessionStream(sessionId).catch(() => {});
|
||||
}
|
||||
|
||||
/**
|
||||
* Deduplicate messages by ID and by consecutive content fingerprint.
|
||||
*
|
||||
* ID dedup catches exact duplicates within the same source.
|
||||
* Content dedup uses a composite key of `role + preceding-user-message-id +
|
||||
|
||||
@@ -17,6 +17,7 @@ import {
|
||||
hasActiveBackendStream,
|
||||
resolveInProgressTools,
|
||||
getSendSuppressionReason,
|
||||
disconnectSessionStream,
|
||||
} from "./helpers";
|
||||
import type { CopilotLlmModel, CopilotMode } from "./store";
|
||||
|
||||
@@ -153,16 +154,15 @@ export function useCopilotStream({
|
||||
reconnectTimerRef.current = setTimeout(() => {
|
||||
isReconnectScheduledRef.current = false;
|
||||
setIsReconnectScheduled(false);
|
||||
// Strip any stale in-progress assistant message before resuming.
|
||||
// The backend replays from "0-0", so the partial message would
|
||||
// otherwise sit alongside the fully-replayed version.
|
||||
// Strip the stale in-progress assistant message before resuming —
|
||||
// the backend replays from "0-0", so keeping it would duplicate parts.
|
||||
setMessages((prev) => {
|
||||
if (prev.length > 0 && prev[prev.length - 1].role === "assistant") {
|
||||
return prev.slice(0, -1);
|
||||
}
|
||||
return prev;
|
||||
});
|
||||
resumeStream();
|
||||
resumeStreamRef.current();
|
||||
}, delay);
|
||||
}
|
||||
|
||||
@@ -260,6 +260,14 @@ export function useCopilotStream({
|
||||
},
|
||||
});
|
||||
|
||||
// Keep stable refs to sdkStop and resumeStream so that async callbacks
|
||||
// (session-switch cleanup, wake re-sync, reconnect timer) always call the
|
||||
// latest version without stale-closure bugs.
|
||||
const sdkStopRef = useRef(sdkStop);
|
||||
sdkStopRef.current = sdkStop;
|
||||
const resumeStreamRef = useRef(resumeStream);
|
||||
resumeStreamRef.current = resumeStream;
|
||||
|
||||
// Wrap sdkSendMessage to guard against re-sending the user message during a
|
||||
// reconnect cycle. If the session already has the message (i.e. we are in a
|
||||
// reconnect/resume flow), only GET-resume is safe — never re-POST.
|
||||
@@ -386,7 +394,7 @@ export function useCopilotStream({
|
||||
}
|
||||
return prev;
|
||||
});
|
||||
await resumeStream();
|
||||
await resumeStreamRef.current();
|
||||
}
|
||||
// If !backendActive, the refetch will update hydratedMessages via
|
||||
// React Query, and the hydration effect below will merge them in.
|
||||
@@ -409,7 +417,7 @@ export function useCopilotStream({
|
||||
return () => {
|
||||
document.removeEventListener("visibilitychange", onVisibilityChange);
|
||||
};
|
||||
}, [refetchSession, setMessages, resumeStream]);
|
||||
}, [refetchSession, setMessages]);
|
||||
|
||||
// Hydrate messages from REST API when not actively streaming
|
||||
useEffect(() => {
|
||||
@@ -425,8 +433,34 @@ export function useCopilotStream({
|
||||
// Track resume state per session
|
||||
const hasResumedRef = useRef<Map<string, boolean>>(new Map());
|
||||
|
||||
// Clean up reconnect state on session switch
|
||||
// Clean up reconnect state on session switch.
|
||||
// Abort the old stream's in-flight fetch and tell the backend to release
|
||||
// its XREAD listeners immediately (fire-and-forget).
|
||||
const prevStreamSessionRef = useRef(sessionId);
|
||||
useEffect(() => {
|
||||
const prevSid = prevStreamSessionRef.current;
|
||||
prevStreamSessionRef.current = sessionId;
|
||||
|
||||
const isSwitching = Boolean(prevSid && prevSid !== sessionId);
|
||||
if (isSwitching) {
|
||||
// Mark BEFORE stopping so the old stream's async onError (which fires
|
||||
// after the abort) sees the flag and short-circuits the reconnect path.
|
||||
// Without this, the AbortError can queue a reconnect against the new
|
||||
// session's `sessionId` (captured in the fresh onError closure).
|
||||
isUserStoppingRef.current = true;
|
||||
sdkStopRef.current();
|
||||
disconnectSessionStream(prevSid!);
|
||||
// Schedule the reset as a task (not a microtask) so it runs AFTER the
|
||||
// aborted fetch's onError has fired — otherwise the new session would
|
||||
// be stuck with the "user stopping" flag set, preventing auto-resume
|
||||
// when hydration detects an active backend stream.
|
||||
setTimeout(() => {
|
||||
isUserStoppingRef.current = false;
|
||||
}, 0);
|
||||
} else {
|
||||
isUserStoppingRef.current = false;
|
||||
}
|
||||
|
||||
clearTimeout(reconnectTimerRef.current);
|
||||
reconnectTimerRef.current = undefined;
|
||||
reconnectAttemptsRef.current = 0;
|
||||
@@ -434,7 +468,6 @@ export function useCopilotStream({
|
||||
setIsReconnectScheduled(false);
|
||||
setRateLimitMessage(null);
|
||||
hasShownDisconnectToast.current = false;
|
||||
isUserStoppingRef.current = false;
|
||||
lastSubmittedMsgRef.current = null;
|
||||
setReconnectExhausted(false);
|
||||
setIsSyncing(false);
|
||||
@@ -501,15 +534,8 @@ export function useCopilotStream({
|
||||
return prev;
|
||||
});
|
||||
|
||||
resumeStream();
|
||||
}, [
|
||||
sessionId,
|
||||
hasActiveStream,
|
||||
hydratedMessages,
|
||||
status,
|
||||
resumeStream,
|
||||
setMessages,
|
||||
]);
|
||||
resumeStreamRef.current();
|
||||
}, [sessionId, hasActiveStream, hydratedMessages, status, setMessages]);
|
||||
|
||||
// Clear messages when session is null
|
||||
useEffect(() => {
|
||||
|
||||
@@ -1633,6 +1633,35 @@
|
||||
}
|
||||
},
|
||||
"/api/chat/sessions/{session_id}/stream": {
|
||||
"delete": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Disconnect Session Stream",
|
||||
"description": "Disconnect all active SSE listeners for a session.\n\nCalled by the frontend when the user switches away from a chat so the\nbackend releases XREAD listeners immediately rather than waiting for\nthe 5-10 s timeout.",
|
||||
"operationId": "deleteV2DisconnectSessionStream",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "session_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Session Id" }
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": { "description": "Successful Response" },
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Resume Session Stream",
|
||||
|
||||
Reference in New Issue
Block a user