Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into fix/copilot-model-toggle-styling

This commit is contained in:
majdyz
2026-04-15 17:58:50 +07:00
7 changed files with 313 additions and 18 deletions

View File

@@ -381,6 +381,31 @@ async def delete_session(
return Response(status_code=204)
@router.delete(
"/sessions/{session_id}/stream",
dependencies=[Security(auth.requires_user)],
status_code=204,
)
async def disconnect_session_stream(
session_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> Response:
"""Disconnect all active SSE listeners for a session.
Called by the frontend when the user switches away from a chat so the
backend releases XREAD listeners immediately rather than waiting for
the 5-10 s timeout.
"""
session = await get_chat_session(session_id, user_id)
if not session:
raise HTTPException(
status_code=404,
detail=f"Session {session_id} not found or access denied",
)
await stream_registry.disconnect_all_listeners(session_id)
return Response(status_code=204)
@router.patch(
"/sessions/{session_id}/title",
summary="Update session title",

View File

@@ -677,3 +677,48 @@ class TestStripInjectedContext:
result = _strip_injected_context(msg)
# Without a role, the helper short-circuits without touching content.
assert result["content"] == "hello"
# ─── DELETE /sessions/{id}/stream — disconnect listeners ──────────────
def test_disconnect_stream_returns_204_and_awaits_registry(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
mock_session = MagicMock()
mocker.patch(
"backend.api.features.chat.routes.get_chat_session",
new_callable=AsyncMock,
return_value=mock_session,
)
mock_disconnect = mocker.patch(
"backend.api.features.chat.routes.stream_registry.disconnect_all_listeners",
new_callable=AsyncMock,
return_value=2,
)
response = client.delete("/sessions/sess-1/stream")
assert response.status_code == 204
mock_disconnect.assert_awaited_once_with("sess-1")
def test_disconnect_stream_returns_404_when_session_missing(
mocker: pytest_mock.MockerFixture,
test_user_id: str,
) -> None:
mocker.patch(
"backend.api.features.chat.routes.get_chat_session",
new_callable=AsyncMock,
return_value=None,
)
mock_disconnect = mocker.patch(
"backend.api.features.chat.routes.stream_registry.disconnect_all_listeners",
new_callable=AsyncMock,
)
response = client.delete("/sessions/unknown-session/stream")
assert response.status_code == 404
mock_disconnect.assert_not_awaited()

View File

@@ -1149,3 +1149,50 @@ async def unsubscribe_from_session(
)
logger.debug(f"Successfully unsubscribed from session {session_id}")
async def disconnect_all_listeners(session_id: str) -> int:
"""Cancel every active listener task for *session_id*.
Called when the frontend switches away from a session and wants the
backend to release resources immediately rather than waiting for the
XREAD timeout.
Scope / limitations (best-effort optimisation, not a correctness primitive):
- Pod-local: ``_listener_sessions`` is in-memory. If the DELETE request
lands on a different worker than the one serving the SSE, no listener
is cancelled here — the SSE worker still releases on its XREAD timeout.
- Session-scoped (not subscriber-scoped): cancels every active listener
for the session on this pod. In the rare case a single user opens two
SSE connections to the same session on the same pod (e.g. two tabs),
both would be torn down. Cross-pod, subscriber-scoped cancellation
would require a Redis pub/sub fan-out with per-listener tokens; that
is not implemented here because the XREAD timeout already bounds the
worst case.
Returns the number of listener tasks that were cancelled.
"""
to_cancel: list[tuple[int, asyncio.Task]] = [
(qid, task)
for qid, (sid, task) in list(_listener_sessions.items())
if sid == session_id and not task.done()
]
for qid, task in to_cancel:
_listener_sessions.pop(qid, None)
task.cancel()
cancelled = 0
for _qid, task in to_cancel:
try:
await asyncio.wait_for(task, timeout=5.0)
except asyncio.CancelledError:
cancelled += 1
except asyncio.TimeoutError:
pass
except Exception as e:
logger.error(f"Error cancelling listener for session {session_id}: {e}")
if cancelled:
logger.info(f"Disconnected {cancelled} listener(s) for session {session_id}")
return cancelled

View File

@@ -0,0 +1,110 @@
"""Tests for disconnect_all_listeners in stream_registry."""
import asyncio
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot import stream_registry
@pytest.fixture(autouse=True)
def _clear_listener_sessions():
stream_registry._listener_sessions.clear()
yield
stream_registry._listener_sessions.clear()
async def _sleep_forever():
try:
await asyncio.sleep(3600)
except asyncio.CancelledError:
raise
@pytest.mark.asyncio
async def test_disconnect_all_listeners_cancels_matching_session():
task_a = asyncio.create_task(_sleep_forever())
task_b = asyncio.create_task(_sleep_forever())
task_other = asyncio.create_task(_sleep_forever())
stream_registry._listener_sessions[1] = ("sess-1", task_a)
stream_registry._listener_sessions[2] = ("sess-1", task_b)
stream_registry._listener_sessions[3] = ("sess-other", task_other)
try:
cancelled = await stream_registry.disconnect_all_listeners("sess-1")
assert cancelled == 2
assert task_a.cancelled()
assert task_b.cancelled()
assert not task_other.done()
# Matching entries are removed, non-matching entries remain.
assert 1 not in stream_registry._listener_sessions
assert 2 not in stream_registry._listener_sessions
assert 3 in stream_registry._listener_sessions
finally:
task_other.cancel()
try:
await task_other
except asyncio.CancelledError:
pass
@pytest.mark.asyncio
async def test_disconnect_all_listeners_no_match_returns_zero():
task = asyncio.create_task(_sleep_forever())
stream_registry._listener_sessions[1] = ("sess-other", task)
try:
cancelled = await stream_registry.disconnect_all_listeners("sess-missing")
assert cancelled == 0
assert not task.done()
assert 1 in stream_registry._listener_sessions
finally:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
@pytest.mark.asyncio
async def test_disconnect_all_listeners_skips_already_done_tasks():
async def _noop():
return None
done_task = asyncio.create_task(_noop())
await done_task
stream_registry._listener_sessions[1] = ("sess-1", done_task)
cancelled = await stream_registry.disconnect_all_listeners("sess-1")
# Done tasks are filtered out before cancellation.
assert cancelled == 0
@pytest.mark.asyncio
async def test_disconnect_all_listeners_empty_registry():
cancelled = await stream_registry.disconnect_all_listeners("sess-1")
assert cancelled == 0
@pytest.mark.asyncio
async def test_disconnect_all_listeners_timeout_not_counted():
"""Tasks that don't respond to cancellation (timeout) are not counted."""
task = asyncio.create_task(_sleep_forever())
stream_registry._listener_sessions[1] = ("sess-1", task)
with patch.object(
asyncio, "wait_for", new=AsyncMock(side_effect=asyncio.TimeoutError)
):
cancelled = await stream_registry.disconnect_all_listeners("sess-1")
assert cancelled == 0
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

View File

@@ -2,6 +2,8 @@ import { getSystemHeaders } from "@/lib/impersonation";
import { getWebSocketToken } from "@/lib/supabase/actions";
import type { UIMessage } from "ai";
import { deleteV2DisconnectSessionStream } from "@/app/api/__generated__/endpoints/chat/chat";
export const ORIGINAL_TITLE = "AutoGPT";
/**
@@ -154,7 +156,18 @@ export function shouldSuppressDuplicateSend(
}
/**
* Deduplicate messages by ID and by content fingerprint.
* Fire-and-forget: tell the backend to release XREAD listeners for a session.
*
* Called on session switch so the backend doesn't wait for its 5-10 s timeout
* before cleaning up. Failures are silently ignored — the backend will
* eventually clean up on its own.
*/
export function disconnectSessionStream(sessionId: string): void {
deleteV2DisconnectSessionStream(sessionId).catch(() => {});
}
/**
* Deduplicate messages by ID and by consecutive content fingerprint.
*
* ID dedup catches exact duplicates within the same source.
* Content dedup uses a composite key of `role + preceding-user-message-id +

View File

@@ -17,6 +17,7 @@ import {
hasActiveBackendStream,
resolveInProgressTools,
getSendSuppressionReason,
disconnectSessionStream,
} from "./helpers";
import type { CopilotLlmModel, CopilotMode } from "./store";
@@ -153,16 +154,15 @@ export function useCopilotStream({
reconnectTimerRef.current = setTimeout(() => {
isReconnectScheduledRef.current = false;
setIsReconnectScheduled(false);
// Strip any stale in-progress assistant message before resuming.
// The backend replays from "0-0", so the partial message would
// otherwise sit alongside the fully-replayed version.
// Strip the stale in-progress assistant message before resuming
// the backend replays from "0-0", so keeping it would duplicate parts.
setMessages((prev) => {
if (prev.length > 0 && prev[prev.length - 1].role === "assistant") {
return prev.slice(0, -1);
}
return prev;
});
resumeStream();
resumeStreamRef.current();
}, delay);
}
@@ -260,6 +260,14 @@ export function useCopilotStream({
},
});
// Keep stable refs to sdkStop and resumeStream so that async callbacks
// (session-switch cleanup, wake re-sync, reconnect timer) always call the
// latest version without stale-closure bugs.
const sdkStopRef = useRef(sdkStop);
sdkStopRef.current = sdkStop;
const resumeStreamRef = useRef(resumeStream);
resumeStreamRef.current = resumeStream;
// Wrap sdkSendMessage to guard against re-sending the user message during a
// reconnect cycle. If the session already has the message (i.e. we are in a
// reconnect/resume flow), only GET-resume is safe — never re-POST.
@@ -386,7 +394,7 @@ export function useCopilotStream({
}
return prev;
});
await resumeStream();
await resumeStreamRef.current();
}
// If !backendActive, the refetch will update hydratedMessages via
// React Query, and the hydration effect below will merge them in.
@@ -409,7 +417,7 @@ export function useCopilotStream({
return () => {
document.removeEventListener("visibilitychange", onVisibilityChange);
};
}, [refetchSession, setMessages, resumeStream]);
}, [refetchSession, setMessages]);
// Hydrate messages from REST API when not actively streaming
useEffect(() => {
@@ -425,8 +433,34 @@ export function useCopilotStream({
// Track resume state per session
const hasResumedRef = useRef<Map<string, boolean>>(new Map());
// Clean up reconnect state on session switch
// Clean up reconnect state on session switch.
// Abort the old stream's in-flight fetch and tell the backend to release
// its XREAD listeners immediately (fire-and-forget).
const prevStreamSessionRef = useRef(sessionId);
useEffect(() => {
const prevSid = prevStreamSessionRef.current;
prevStreamSessionRef.current = sessionId;
const isSwitching = Boolean(prevSid && prevSid !== sessionId);
if (isSwitching) {
// Mark BEFORE stopping so the old stream's async onError (which fires
// after the abort) sees the flag and short-circuits the reconnect path.
// Without this, the AbortError can queue a reconnect against the new
// session's `sessionId` (captured in the fresh onError closure).
isUserStoppingRef.current = true;
sdkStopRef.current();
disconnectSessionStream(prevSid!);
// Schedule the reset as a task (not a microtask) so it runs AFTER the
// aborted fetch's onError has fired — otherwise the new session would
// be stuck with the "user stopping" flag set, preventing auto-resume
// when hydration detects an active backend stream.
setTimeout(() => {
isUserStoppingRef.current = false;
}, 0);
} else {
isUserStoppingRef.current = false;
}
clearTimeout(reconnectTimerRef.current);
reconnectTimerRef.current = undefined;
reconnectAttemptsRef.current = 0;
@@ -434,7 +468,6 @@ export function useCopilotStream({
setIsReconnectScheduled(false);
setRateLimitMessage(null);
hasShownDisconnectToast.current = false;
isUserStoppingRef.current = false;
lastSubmittedMsgRef.current = null;
setReconnectExhausted(false);
setIsSyncing(false);
@@ -501,15 +534,8 @@ export function useCopilotStream({
return prev;
});
resumeStream();
}, [
sessionId,
hasActiveStream,
hydratedMessages,
status,
resumeStream,
setMessages,
]);
resumeStreamRef.current();
}, [sessionId, hasActiveStream, hydratedMessages, status, setMessages]);
// Clear messages when session is null
useEffect(() => {

View File

@@ -1633,6 +1633,35 @@
}
},
"/api/chat/sessions/{session_id}/stream": {
"delete": {
"tags": ["v2", "chat", "chat"],
"summary": "Disconnect Session Stream",
"description": "Disconnect all active SSE listeners for a session.\n\nCalled by the frontend when the user switches away from a chat so the\nbackend releases XREAD listeners immediately rather than waiting for\nthe 5-10 s timeout.",
"operationId": "deleteV2DisconnectSessionStream",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "session_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Session Id" }
}
],
"responses": {
"204": { "description": "Successful Response" },
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
},
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Resume Session Stream",