mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(backend/copilot): single-owner StreamFinish via mark_session_completed
mark_session_completed() is now the SINGLE place that publishes StreamFinish to the turn stream. Simplified API: - mark_session_completed(session_id) → completed - mark_session_completed(session_id, error_message='...') → failed Flow: set status (Lua CAS) → StreamError if failed → StreamFinish. The processor intercepts StreamFinish from generators and calls mark_session_completed instead. Removed _mark_task_failed (redundant). Removed cleanup_turn_stream (streams have a TTL, and eager deletion raced with _stream_listener's xread).
This commit is contained in:
@@ -12,7 +12,7 @@ import time
|
||||
from backend.copilot import service as copilot_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamFinishStep
|
||||
from backend.copilot.response_model import StreamFinish, StreamFinishStep
|
||||
from backend.copilot.sdk import service as sdk_service
|
||||
from backend.executor.cluster_lock import ClusterLock
|
||||
from backend.util.decorator import error_logged
|
||||
@@ -238,25 +238,23 @@ class CoPilotProcessor:
|
||||
# Check for cancellation
|
||||
if cancel.is_set():
|
||||
log.info("Cancelled during streaming")
|
||||
await stream_registry.publish_chunk(
|
||||
entry.turn_id, StreamError(errorText="Operation cancelled")
|
||||
)
|
||||
await stream_registry.publish_chunk(
|
||||
entry.turn_id, StreamFinishStep()
|
||||
)
|
||||
await stream_registry.publish_chunk(entry.turn_id, StreamFinish())
|
||||
await stream_registry.mark_session_completed(
|
||||
entry.session_id, status="failed"
|
||||
entry.session_id,
|
||||
error_message="Operation cancelled",
|
||||
)
|
||||
return
|
||||
|
||||
# Refresh cluster lock periodically
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_refresh >= refresh_interval:
|
||||
cluster_lock.refresh()
|
||||
last_refresh = current_time
|
||||
|
||||
# Publish chunk to stream registry
|
||||
if isinstance(chunk, StreamFinish):
|
||||
break
|
||||
|
||||
try:
|
||||
await stream_registry.publish_chunk(entry.turn_id, chunk)
|
||||
except Exception as e:
|
||||
@@ -265,39 +263,25 @@ class CoPilotProcessor:
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Mark session as completed
|
||||
await stream_registry.mark_session_completed(
|
||||
entry.session_id, status="completed"
|
||||
)
|
||||
await stream_registry.mark_session_completed(entry.session_id)
|
||||
log.info("Task completed successfully")
|
||||
|
||||
if entry.turn_id:
|
||||
await stream_registry.cleanup_turn_stream(entry.turn_id)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
log.info("Task cancelled")
|
||||
await stream_registry.mark_session_completed(
|
||||
entry.session_id,
|
||||
status="failed",
|
||||
error_message="Task was cancelled",
|
||||
entry.session_id, error_message="Task was cancelled"
|
||||
)
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Task failed: {e}")
|
||||
await self._mark_task_failed(entry.session_id, str(e), entry.turn_id)
|
||||
try:
|
||||
await stream_registry.publish_chunk(entry.turn_id, StreamFinishStep())
|
||||
await stream_registry.mark_session_completed(
|
||||
entry.session_id, error_message=str(e)
|
||||
)
|
||||
except Exception as mark_err:
|
||||
logger.error(
|
||||
f"Failed to mark session {entry.session_id} as failed: {mark_err}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def _mark_task_failed(
|
||||
self, session_id: str, error_message: str, turn_id: str = ""
|
||||
):
|
||||
"""Mark a task as failed and publish error to stream registry."""
|
||||
try:
|
||||
await stream_registry.publish_chunk(
|
||||
turn_id, StreamError(errorText=error_message)
|
||||
)
|
||||
await stream_registry.publish_chunk(turn_id, StreamFinishStep())
|
||||
await stream_registry.publish_chunk(turn_id, StreamFinish())
|
||||
await stream_registry.mark_session_completed(session_id, status="failed")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to mark session {session_id} as failed: {e}")
|
||||
|
||||
@@ -685,27 +685,28 @@ async def _stream_listener(
|
||||
|
||||
async def mark_session_completed(
|
||||
session_id: str,
|
||||
status: Literal["completed", "failed"] = "completed",
|
||||
*,
|
||||
error_message: str | None = None,
|
||||
) -> bool:
|
||||
"""Mark a session as completed and publish finish event.
|
||||
"""Mark a session as completed, then publish StreamFinish.
|
||||
|
||||
This is the SINGLE place that publishes StreamFinish to the turn stream.
|
||||
Services must NOT yield StreamFinish themselves — the processor intercepts
|
||||
it and calls this function instead, ensuring status is set before
|
||||
StreamFinish reaches the frontend.
|
||||
|
||||
This is idempotent - calling multiple times with the same session_id is safe.
|
||||
Uses atomic compare-and-swap via Lua script to prevent race conditions.
|
||||
Status is updated first (source of truth), then finish event is published (best-effort).
|
||||
Idempotent — calling multiple times is safe (returns False on no-op).
|
||||
|
||||
Args:
|
||||
session_id: Session ID to mark as completed
|
||||
status: Final status ("completed" or "failed")
|
||||
error_message: If provided and status="failed", publish a StreamError
|
||||
before StreamFinish so connected clients see why the session ended.
|
||||
If not provided, no StreamError is published (caller should publish
|
||||
manually if needed to avoid duplicates).
|
||||
error_message: If provided, marks as "failed" and publishes a
|
||||
StreamError before StreamFinish. Otherwise marks as "completed".
|
||||
|
||||
Returns:
|
||||
True if session was newly marked completed, False if already completed/failed
|
||||
"""
|
||||
status: Literal["completed", "failed"] = "failed" if error_message else "completed"
|
||||
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_session_meta_key(session_id)
|
||||
|
||||
@@ -714,19 +715,13 @@ async def mark_session_completed(
|
||||
turn_id = _parse_session_meta(meta, session_id).turn_id if meta else session_id
|
||||
|
||||
# Atomic compare-and-swap: only update if status is "running"
|
||||
# This prevents race conditions when multiple callers try to complete simultaneously
|
||||
result = await redis.eval(COMPLETE_SESSION_SCRIPT, 1, meta_key, status) # type: ignore[misc]
|
||||
|
||||
if result == 0:
|
||||
logger.debug(f"Session {session_id} already completed/failed, skipping")
|
||||
return False
|
||||
|
||||
# Publish error event before finish so connected clients know WHY the
|
||||
# session ended. Only publish if caller provided an explicit error message
|
||||
# to avoid duplicates with code paths that manually publish StreamError.
|
||||
# This is best-effort — if it fails, the StreamFinish still ensures
|
||||
# listeners clean up.
|
||||
if status == "failed" and error_message:
|
||||
if error_message:
|
||||
try:
|
||||
await publish_chunk(turn_id, StreamError(errorText=error_message))
|
||||
except Exception as e:
|
||||
@@ -734,13 +729,15 @@ async def mark_session_completed(
|
||||
f"Failed to publish error event for session {session_id}: {e}"
|
||||
)
|
||||
|
||||
# THEN publish finish event (best-effort - listeners can detect via status polling)
|
||||
# Publish StreamFinish AFTER status is set to "completed"/"failed".
|
||||
# This is the SINGLE place that publishes StreamFinish — services and
|
||||
# the processor must NOT publish it themselves.
|
||||
try:
|
||||
await publish_chunk(turn_id, StreamFinish())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to publish finish event for session {session_id}: {e}. "
|
||||
"Listeners will detect completion via status polling."
|
||||
f"Failed to publish StreamFinish for session {session_id}: {e}. "
|
||||
"The _stream_listener will detect completion via status polling."
|
||||
)
|
||||
|
||||
# Clean up local session reference if exists
|
||||
@@ -852,7 +849,6 @@ async def get_active_session(
|
||||
)
|
||||
await mark_session_completed(
|
||||
session_id,
|
||||
status="failed",
|
||||
error_message=f"Session timed out after {age_seconds:.0f}s",
|
||||
)
|
||||
return None, "0-0"
|
||||
@@ -999,13 +995,3 @@ async def unsubscribe_from_session(
|
||||
)
|
||||
|
||||
logger.debug(f"Successfully unsubscribed from session {session_id}")
|
||||
|
||||
|
||||
async def cleanup_turn_stream(turn_id: str) -> None:
|
||||
"""Delete the per-turn Redis stream after the turn completes."""
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
stream_key = _get_turn_stream_key(turn_id)
|
||||
await redis.delete(stream_key)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up turn stream {turn_id}: {e}")
|
||||
|
||||
@@ -276,32 +276,6 @@ export function useCopilotPage() {
|
||||
resumeStream();
|
||||
}, [sessionId, hasActiveStream, hydratedMessages, status, resumeStream]);
|
||||
|
||||
// Poll for task completion when streaming to detect stuck streams
|
||||
// This prevents UI from getting stuck if StreamFinish event is missed
|
||||
useEffect(() => {
|
||||
if (!sessionId) return;
|
||||
if (status !== "streaming" && status !== "submitted") return;
|
||||
|
||||
const pollInterval = setInterval(() => {
|
||||
// Invalidate session query to check if backend task has completed
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(sessionId),
|
||||
});
|
||||
}, 3000); // Poll every 3 seconds
|
||||
|
||||
return () => clearInterval(pollInterval);
|
||||
}, [sessionId, status, queryClient]);
|
||||
|
||||
// If backend says no active stream but frontend thinks it's streaming, stop it
|
||||
useEffect(() => {
|
||||
if (!sessionId) return;
|
||||
if (!hasActiveStream && (status === "streaming" || status === "submitted")) {
|
||||
// Backend has completed but frontend is still streaming - force stop
|
||||
sdkStop();
|
||||
setMessages((prev) => resolveInProgressTools(prev, "completed"));
|
||||
}
|
||||
}, [hasActiveStream, status, sessionId, sdkStop, setMessages]);
|
||||
|
||||
// Clear messages when session is null
|
||||
useEffect(() => {
|
||||
if (!sessionId) setMessages([]);
|
||||
|
||||
Reference in New Issue
Block a user