mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-03 03:14:57 -05:00
fixing sse reconnection
This commit is contained in:
@@ -3,12 +3,16 @@
|
||||
This module provides a consumer that listens for completion notifications
|
||||
from external services (like Agent Generator) and triggers the appropriate
|
||||
stream registry and chat service updates.
|
||||
|
||||
The consumer initializes its own Prisma client to avoid async context issues.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
import orjson
|
||||
from prisma import Prisma
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.rabbitmq import (
|
||||
@@ -57,12 +61,17 @@ class OperationCompleteMessage(BaseModel):
|
||||
|
||||
|
||||
class ChatCompletionConsumer:
|
||||
"""Consumer for chat operation completion messages from RabbitMQ."""
|
||||
"""Consumer for chat operation completion messages from RabbitMQ.
|
||||
|
||||
This consumer creates its own Prisma client lazily on first use (via
|
||||
_ensure_prisma(), not in start()) so that database operations run in the same async context as the message handler.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._rabbitmq: AsyncRabbitMQ | None = None
|
||||
self._consumer_task: asyncio.Task | None = None
|
||||
self._running = False
|
||||
self._prisma: Prisma | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the completion consumer."""
|
||||
@@ -70,6 +79,9 @@ class ChatCompletionConsumer:
|
||||
logger.warning("Completion consumer already running")
|
||||
return
|
||||
|
||||
# Don't initialize Prisma here - do it lazily on first message
|
||||
# to ensure it's in the same async context as the message handler
|
||||
|
||||
self._rabbitmq = AsyncRabbitMQ(RABBITMQ_CONFIG)
|
||||
await self._rabbitmq.connect()
|
||||
|
||||
@@ -77,6 +89,15 @@ class ChatCompletionConsumer:
|
||||
self._consumer_task = asyncio.create_task(self._consume_messages())
|
||||
logger.info("Chat completion consumer started")
|
||||
|
||||
async def _ensure_prisma(self) -> Prisma:
    """Return this consumer's Prisma client, connecting it on first use.

    The client is created lazily and cached on ``self._prisma``. It is only
    cached after ``connect()`` has succeeded, so a failed connection attempt
    is retried on the next call instead of leaving an unconnected client
    behind.

    Returns:
        The connected Prisma client cached on this consumer.

    Raises:
        Exception: Whatever ``Prisma.connect()`` raises; in that case nothing
            is cached.
    """
    if self._prisma is None:
        database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
        client = Prisma(datasource={"url": database_url})
        # Connect BEFORE storing the client on self. Previously the client
        # was stored first, so if connect() raised, every later call saw a
        # non-None self._prisma and returned a client that was never
        # connected.
        # NOTE(review): two overlapping first calls can still each create a
        # client; confirm whether messages are handled concurrently.
        await client.connect()
        self._prisma = client
        logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
    return self._prisma
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the completion consumer."""
|
||||
self._running = False
|
||||
@@ -93,6 +114,11 @@ class ChatCompletionConsumer:
|
||||
await self._rabbitmq.disconnect()
|
||||
self._rabbitmq = None
|
||||
|
||||
if self._prisma:
|
||||
await self._prisma.disconnect()
|
||||
self._prisma = None
|
||||
logger.info("[COMPLETION] Consumer Prisma client disconnected")
|
||||
|
||||
logger.info("Chat completion consumer stopped")
|
||||
|
||||
async def _consume_messages(self) -> None:
|
||||
@@ -144,7 +170,7 @@ class ChatCompletionConsumer:
|
||||
return
|
||||
|
||||
async def _handle_message(self, body: bytes) -> None:
|
||||
"""Handle a single completion message."""
|
||||
"""Handle a completion message using our own Prisma client."""
|
||||
try:
|
||||
data = orjson.loads(body)
|
||||
message = OperationCompleteMessage(**data)
|
||||
@@ -153,23 +179,36 @@ class ChatCompletionConsumer:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Received completion for operation {message.operation_id} "
|
||||
f"[COMPLETION] Received completion for operation {message.operation_id} "
|
||||
f"(task_id={message.task_id}, success={message.success})"
|
||||
)
|
||||
|
||||
# Find task in registry
|
||||
task = await stream_registry.find_task_by_operation_id(message.operation_id)
|
||||
if task is None:
|
||||
# Try to look up by task_id directly
|
||||
task = await stream_registry.get_task(message.task_id)
|
||||
|
||||
if task is None:
|
||||
logger.warning(
|
||||
f"Task not found for operation {message.operation_id} "
|
||||
f"[COMPLETION] Task not found for operation {message.operation_id} "
|
||||
f"(task_id={message.task_id})"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Found task: task_id={task.task_id}, "
|
||||
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
|
||||
)
|
||||
|
||||
# Guard against empty task fields
|
||||
if not task.task_id or not task.session_id or not task.tool_call_id:
|
||||
logger.error(
|
||||
f"[COMPLETION] Task has empty critical fields! "
|
||||
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
|
||||
f"tool_call_id={task.tool_call_id!r}"
|
||||
)
|
||||
return
|
||||
|
||||
if message.success:
|
||||
await self._handle_success(task, message)
|
||||
else:
|
||||
@@ -197,7 +236,7 @@ class ChatCompletionConsumer:
|
||||
),
|
||||
)
|
||||
|
||||
# Update pending operation in database
|
||||
# Update pending operation in database using our Prisma client
|
||||
result_str = (
|
||||
message.result
|
||||
if isinstance(message.result, str)
|
||||
@@ -207,26 +246,45 @@ class ChatCompletionConsumer:
|
||||
else '{"status": "completed"}'
|
||||
)
|
||||
)
|
||||
await chat_service._update_pending_operation(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
result=result_str,
|
||||
)
|
||||
try:
|
||||
prisma = await self._ensure_prisma()
|
||||
await prisma.chatmessage.update_many(
|
||||
where={
|
||||
"sessionId": task.session_id,
|
||||
"toolCallId": task.tool_call_id,
|
||||
},
|
||||
data={"content": result_str},
|
||||
)
|
||||
logger.info(
|
||||
f"[COMPLETION] Updated tool message for session {task.session_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to update tool message: {e}", exc_info=True
|
||||
)
|
||||
|
||||
# Generate LLM continuation with streaming
|
||||
await chat_service._generate_llm_continuation_with_streaming(
|
||||
session_id=task.session_id,
|
||||
user_id=task.user_id,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
try:
|
||||
await chat_service._generate_llm_continuation_with_streaming(
|
||||
session_id=task.session_id,
|
||||
user_id=task.user_id,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to generate LLM continuation: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Mark task as completed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="completed")
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
try:
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||
|
||||
logger.info(
|
||||
f"Successfully processed completion for task {task.task_id} "
|
||||
f"(operation {message.operation_id})"
|
||||
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
|
||||
)
|
||||
|
||||
async def _handle_failure(
|
||||
@@ -237,31 +295,44 @@ class ChatCompletionConsumer:
|
||||
"""Handle failed operation completion."""
|
||||
error_msg = message.error or "Operation failed"
|
||||
|
||||
# Publish error to stream registry followed by finish event
|
||||
# Publish error to stream registry
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText=error_msg),
|
||||
)
|
||||
await stream_registry.publish_chunk(task.task_id, StreamFinish())
|
||||
|
||||
# Update pending operation with error
|
||||
# Update pending operation with error using our Prisma client
|
||||
error_response = ErrorResponse(
|
||||
message=error_msg,
|
||||
error=message.error,
|
||||
)
|
||||
await chat_service._update_pending_operation(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
result=error_response.model_dump_json(),
|
||||
)
|
||||
try:
|
||||
prisma = await self._ensure_prisma()
|
||||
await prisma.chatmessage.update_many(
|
||||
where={
|
||||
"sessionId": task.session_id,
|
||||
"toolCallId": task.tool_call_id,
|
||||
},
|
||||
data={"content": error_response.model_dump_json()},
|
||||
)
|
||||
logger.info(
|
||||
f"[COMPLETION] Updated tool message with error for session {task.session_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to update tool message: {e}", exc_info=True
|
||||
)
|
||||
|
||||
# Mark task as failed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
try:
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||
|
||||
logger.info(
|
||||
f"Processed failure for task {task.task_id} "
|
||||
f"(operation {message.operation_id}): {error_msg}"
|
||||
f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}"
|
||||
)
|
||||
|
||||
|
||||
@@ -294,9 +365,6 @@ async def publish_operation_complete(
|
||||
) -> None:
|
||||
"""Publish an operation completion message.
|
||||
|
||||
This is a helper function for testing or for services that want to
|
||||
publish completion messages directly.
|
||||
|
||||
Args:
|
||||
operation_id: The operation ID that completed.
|
||||
task_id: The task ID associated with the operation.
|
||||
|
||||
@@ -52,6 +52,10 @@ class StreamStart(StreamBaseResponse):
|
||||
|
||||
type: ResponseType = ResponseType.START
|
||||
messageId: str = Field(..., description="Unique message ID")
|
||||
taskId: str | None = Field(
|
||||
default=None,
|
||||
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
|
||||
)
|
||||
|
||||
|
||||
class StreamFinish(StreamBaseResponse):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Chat API routes for chat session management and streaming via SSE."""
|
||||
|
||||
import logging
|
||||
import uuid as uuid_module
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
|
||||
@@ -16,7 +17,7 @@ from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .config import ChatConfig
|
||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||
from .response_model import StreamFinish, StreamHeartbeat
|
||||
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
@@ -58,6 +59,13 @@ class CreateSessionResponse(BaseModel):
|
||||
user_id: str | None
|
||||
|
||||
|
||||
class ActiveStreamInfo(BaseModel):
|
||||
"""Information about an active stream for reconnection."""
|
||||
|
||||
task_id: str
|
||||
last_message_id: str # Redis Stream message ID for resumption
|
||||
|
||||
|
||||
class SessionDetailResponse(BaseModel):
|
||||
"""Response model providing complete details for a chat session, including messages."""
|
||||
|
||||
@@ -66,6 +74,7 @@ class SessionDetailResponse(BaseModel):
|
||||
updated_at: str
|
||||
user_id: str | None
|
||||
messages: list[dict]
|
||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||
|
||||
|
||||
class SessionSummaryResponse(BaseModel):
|
||||
@@ -177,13 +186,14 @@ async def get_session(
|
||||
Retrieve the details of a specific chat session.
|
||||
|
||||
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
||||
If there's an active stream for this session, returns the task_id for reconnection.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the desired chat session.
|
||||
user_id: The optional authenticated user ID, or None for anonymous access.
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session, or None if not found.
|
||||
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
|
||||
|
||||
"""
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
@@ -191,10 +201,31 @@ async def get_session(
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
|
||||
# Check if there's an active stream for this session
|
||||
active_stream_info = None
|
||||
logger.info(f"[SSE-RECONNECT] Checking for active stream in session {session_id}")
|
||||
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
||||
session_id, user_id
|
||||
)
|
||||
if active_task:
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
task_id=active_task.task_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[SSE-RECONNECT] Session {session_id} HAS active stream: "
|
||||
f"task_id={active_task.task_id}, status={active_task.status}, "
|
||||
f"last_message_id={last_message_id}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"[SSE-RECONNECT] Session {session_id} has NO active stream")
|
||||
|
||||
logger.info(
|
||||
f"Returning session {session_id}: "
|
||||
f"message_count={len(messages)}, "
|
||||
f"roles={[m.get('role') for m in messages]}"
|
||||
f"roles={[m.get('role') for m in messages]}, "
|
||||
f"has_active_stream={active_stream_info is not None}"
|
||||
)
|
||||
|
||||
return SessionDetailResponse(
|
||||
@@ -203,6 +234,7 @@ async def get_session(
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
messages=messages,
|
||||
active_stream=active_stream_info,
|
||||
)
|
||||
|
||||
|
||||
@@ -222,49 +254,136 @@ async def stream_chat_post(
|
||||
- Tool call UI elements (if invoked)
|
||||
- Tool execution results
|
||||
|
||||
The AI generation runs in a background task that continues even if the client disconnects.
|
||||
All chunks are written to Redis for reconnection support. If the client disconnects,
|
||||
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
|
||||
containing the task_id for reconnection.
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
# Create a task in the stream registry for reconnection support
|
||||
task_id = str(uuid_module.uuid4())
|
||||
operation_id = str(uuid_module.uuid4())
|
||||
await stream_registry.create_task(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id="chat_stream", # Not a tool call, but needed for the model
|
||||
tool_name="chat",
|
||||
operation_id=operation_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[SSE-RECONNECT] Created stream task for reconnection support: "
|
||||
f"task_id={task_id}, session_id={session_id}"
|
||||
)
|
||||
|
||||
# Background task that runs the AI generation independently of SSE connection
|
||||
async def run_ai_generation():
|
||||
chunk_count = 0
|
||||
first_chunk_type: str | None = None
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
):
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Chat stream chunk",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_type": str(chunk.type),
|
||||
},
|
||||
)
|
||||
if not first_chunk_type:
|
||||
first_chunk_type = str(chunk.type)
|
||||
chunk_count += 1
|
||||
yield chunk.to_sse()
|
||||
logger.info(
|
||||
"Chat stream completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_count": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
)
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
try:
|
||||
# Emit a start event with task_id for reconnection
|
||||
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
|
||||
await stream_registry.publish_chunk(task_id, start_chunk)
|
||||
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
):
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Chat stream chunk",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_type": str(chunk.type),
|
||||
},
|
||||
)
|
||||
if not first_chunk_type:
|
||||
first_chunk_type = str(chunk.type)
|
||||
chunk_count += 1
|
||||
# Write to Redis (subscribers will receive via pub/sub or polling)
|
||||
await stream_registry.publish_chunk(task_id, chunk)
|
||||
|
||||
# Mark task as completed
|
||||
await stream_registry.mark_task_completed(task_id, "completed")
|
||||
logger.info(
|
||||
"[SSE-RECONNECT] Background AI generation completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"task_id": task_id,
|
||||
"chunk_count": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[SSE-RECONNECT] Error in background AI generation for session "
|
||||
f"{session_id}: {e}"
|
||||
)
|
||||
await stream_registry.mark_task_completed(task_id, "failed")
|
||||
|
||||
# Start the AI generation in a background task
|
||||
bg_task = asyncio.create_task(run_ai_generation())
|
||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||
logger.info(f"[SSE-RECONNECT] Started background AI generation task for {task_id}")
|
||||
|
||||
# SSE endpoint that subscribes to the task's stream
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
    """Relay the task's stream to the client as SSE frames.

    Subscribes to the task from the beginning of its stream ("0-0"),
    forwards every chunk, emits a heartbeat whenever no chunk arrives within
    30 seconds, and stops after forwarding a ``StreamFinish`` chunk. The
    AI SDK terminator ``data: [DONE]`` is emitted exactly once, and only
    when the generator ends on its own (normally or after a logged error).

    Yields:
        SSE-formatted strings.
    """
    try:
        # Subscribe to the task stream (this replays existing messages + live updates)
        subscriber_queue = await stream_registry.subscribe_to_task(
            task_id=task_id,
            user_id=user_id,
            last_message_id="0-0",  # Get all messages from the beginning
        )

        if subscriber_queue is None:
            logger.error(f"Failed to subscribe to task {task_id}")
            yield StreamFinish().to_sse()
        else:
            # Read from the subscriber queue and yield to SSE
            while True:
                try:
                    chunk = await asyncio.wait_for(
                        subscriber_queue.get(), timeout=30.0
                    )
                except asyncio.TimeoutError:
                    # Send heartbeat to keep connection alive
                    yield StreamHeartbeat().to_sse()
                    continue

                yield chunk.to_sse()

                # Check for finish signal
                if isinstance(chunk, StreamFinish):
                    logger.info(
                        f"[SSE-RECONNECT] SSE subscriber received finish for task {task_id}"
                    )
                    break

    except GeneratorExit:
        # Client disconnected - that's fine, background task continues
        logger.info(
            f"[SSE-RECONNECT] SSE client disconnected for task {task_id}, "
            f"background generation continues"
        )
        # Must propagate: an async generator that yields again while being
        # closed raises RuntimeError ("async generator ignored GeneratorExit").
        raise
    except Exception as e:
        logger.error(f"Error in SSE stream for task {task_id}: {e}")

    # AI SDK protocol termination. Deliberately NOT in a ``finally`` block:
    # that version also ran during close/cancellation (yielding into a
    # closing generator) and produced a duplicate [DONE] on the
    # subscribe-failure path.
    yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
@@ -409,6 +528,11 @@ async def stream_task(
|
||||
Raises:
|
||||
NotFoundError: If task_id is not found or user doesn't have access.
|
||||
"""
|
||||
logger.info(
|
||||
f"[SSE-RECONNECT] Client reconnecting to task stream: "
|
||||
f"task_id={task_id}, last_message_id={last_message_id}"
|
||||
)
|
||||
|
||||
# Get subscriber queue from stream registry
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
@@ -417,8 +541,15 @@ async def stream_task(
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
logger.warning(
|
||||
f"[SSE-RECONNECT] Task not found or access denied: task_id={task_id}"
|
||||
)
|
||||
raise NotFoundError(f"Task {task_id} not found or access denied.")
|
||||
|
||||
logger.info(
|
||||
f"[SSE-RECONNECT] Successfully subscribed to task stream: task_id={task_id}"
|
||||
)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
import asyncio
|
||||
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
"""Stream registry for managing reconnectable SSE streams.
|
||||
|
||||
This module provides a registry for tracking active streaming tasks and their
|
||||
messages. It supports:
|
||||
- Creating tasks with unique IDs for long-running operations
|
||||
- Publishing stream messages to both Redis Streams and in-memory queues
|
||||
- Subscribing to tasks with replay of missed messages
|
||||
- Looking up tasks by operation_id for webhook callbacks
|
||||
- Cross-pod real-time delivery via Redis pub/sub
|
||||
messages. It uses Redis for all state management (no in-memory state), making
|
||||
pods stateless and horizontally scalable.
|
||||
|
||||
Architecture:
|
||||
- Redis Stream: Persists all messages for replay
|
||||
- Redis Pub/Sub: Real-time delivery to subscribers
|
||||
- Redis Hash: Task metadata (status, session_id, etc.)
|
||||
|
||||
Subscribers:
|
||||
1. Replay missed messages from Redis Stream
|
||||
2. Subscribe to pub/sub channel for live updates
|
||||
3. No in-memory state required on the subscribing pod
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -25,13 +31,10 @@ from .response_model import StreamBaseResponse, StreamFinish
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
# Track active pub/sub listeners for cross-pod delivery
|
||||
_pubsub_listeners: dict[str, asyncio.Task] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveTask:
|
||||
"""Represents an active streaming task."""
|
||||
"""Represents an active streaming task (metadata only, no in-memory queues)."""
|
||||
|
||||
task_id: str
|
||||
session_id: str
|
||||
@@ -41,22 +44,17 @@ class ActiveTask:
|
||||
operation_id: str
|
||||
status: Literal["running", "completed", "failed"] = "running"
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
queue: asyncio.Queue[StreamBaseResponse] = field(default_factory=asyncio.Queue)
|
||||
asyncio_task: asyncio.Task | None = None
|
||||
# Lock for atomic status checks and subscriber management
|
||||
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
# Set of subscriber queues for fan-out
|
||||
subscribers: set[asyncio.Queue[StreamBaseResponse]] = field(default_factory=set)
|
||||
|
||||
|
||||
# Module-level registry for active tasks
|
||||
_active_tasks: dict[str, ActiveTask] = {}
|
||||
|
||||
# Redis key patterns
|
||||
TASK_META_PREFIX = "chat:task:meta:" # Hash for task metadata
|
||||
TASK_STREAM_PREFIX = "chat:stream:" # Redis Stream for messages
|
||||
TASK_OP_PREFIX = "chat:task:op:" # Operation ID -> task_id mapping
|
||||
TASK_PUBSUB_PREFIX = "chat:task:pubsub:" # Pub/sub channel for cross-pod delivery
|
||||
TASK_PUBSUB_PREFIX = "chat:task:pubsub:" # Pub/sub channel for real-time delivery
|
||||
|
||||
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
|
||||
_local_tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
|
||||
def _get_task_meta_key(task_id: str) -> str:
|
||||
@@ -75,7 +73,7 @@ def _get_operation_mapping_key(operation_id: str) -> str:
|
||||
|
||||
|
||||
def _get_task_pubsub_channel(task_id: str) -> str:
|
||||
"""Get Redis pub/sub channel for task cross-pod delivery."""
|
||||
"""Get Redis pub/sub channel for task real-time delivery."""
|
||||
return f"{TASK_PUBSUB_PREFIX}{task_id}"
|
||||
|
||||
|
||||
@@ -87,7 +85,7 @@ async def create_task(
|
||||
tool_name: str,
|
||||
operation_id: str,
|
||||
) -> ActiveTask:
|
||||
"""Create a new streaming task in memory and Redis.
|
||||
"""Create a new streaming task in Redis.
|
||||
|
||||
Args:
|
||||
task_id: Unique identifier for the task
|
||||
@@ -98,7 +96,7 @@ async def create_task(
|
||||
operation_id: Operation ID for webhook callbacks
|
||||
|
||||
Returns:
|
||||
The created ActiveTask instance
|
||||
The created ActiveTask instance (metadata only)
|
||||
"""
|
||||
task = ActiveTask(
|
||||
task_id=task_id,
|
||||
@@ -109,10 +107,7 @@ async def create_task(
|
||||
operation_id=operation_id,
|
||||
)
|
||||
|
||||
# Store in memory registry
|
||||
_active_tasks[task_id] = task
|
||||
|
||||
# Store metadata in Redis for durability
|
||||
# Store metadata in Redis
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
op_key = _get_operation_mapping_key(operation_id)
|
||||
@@ -136,8 +131,7 @@ async def create_task(
|
||||
await redis.set(op_key, task_id, ex=config.stream_ttl)
|
||||
|
||||
logger.info(
|
||||
f"Created streaming task {task_id} for operation {operation_id} "
|
||||
f"in session {session_id}"
|
||||
f"[SSE-RECONNECT] Created task {task_id} for session {session_id} in Redis"
|
||||
)
|
||||
|
||||
return task
|
||||
@@ -147,41 +141,26 @@ async def publish_chunk(
|
||||
task_id: str,
|
||||
chunk: StreamBaseResponse,
|
||||
) -> str:
|
||||
"""Publish a chunk to the task's stream.
|
||||
"""Publish a chunk to Redis Stream and pub/sub channel.
|
||||
|
||||
Delivers to in-memory subscribers first (for real-time), then persists to
|
||||
Redis Stream (for replay). This order ensures live subscribers get messages
|
||||
even if Redis temporarily fails.
|
||||
All delivery is via Redis - no in-memory state.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to publish to
|
||||
chunk: The stream response chunk to publish
|
||||
|
||||
Returns:
|
||||
The Redis Stream message ID (format: "timestamp-sequence"), or "0-0" if
|
||||
Redis persistence failed
|
||||
The Redis Stream message ID
|
||||
"""
|
||||
# Deliver to in-memory subscribers FIRST for real-time updates
|
||||
task = _active_tasks.get(task_id)
|
||||
if task:
|
||||
async with task.lock:
|
||||
for subscriber_queue in task.subscribers:
|
||||
try:
|
||||
subscriber_queue.put_nowait(chunk)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(
|
||||
f"Subscriber queue full for task {task_id}, dropping chunk"
|
||||
)
|
||||
|
||||
# Then persist to Redis Stream for replay (with error handling)
|
||||
message_id = "0-0"
|
||||
chunk_json = chunk.model_dump_json()
|
||||
message_id = "0-0"
|
||||
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
pubsub_channel = _get_task_pubsub_channel(task_id)
|
||||
|
||||
# Add to Redis Stream with auto-generated ID
|
||||
# The ID format is "timestamp-sequence" which gives us ordering
|
||||
# Write to Redis Stream for persistence/replay
|
||||
raw_id = await redis.xadd(
|
||||
stream_key,
|
||||
{"data": chunk_json},
|
||||
@@ -189,14 +168,13 @@ async def publish_chunk(
|
||||
)
|
||||
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
|
||||
|
||||
# Publish to pub/sub for cross-pod real-time delivery
|
||||
pubsub_channel = _get_task_pubsub_channel(task_id)
|
||||
# Publish to pub/sub for real-time delivery
|
||||
await redis.publish(pubsub_channel, chunk_json)
|
||||
|
||||
logger.debug(f"Published chunk to task {task_id}, message_id={message_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to persist chunk to Redis for task {task_id}: {e}",
|
||||
f"Failed to publish chunk for task {task_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
@@ -210,6 +188,8 @@ async def subscribe_to_task(
|
||||
) -> asyncio.Queue[StreamBaseResponse] | None:
|
||||
"""Subscribe to a task's stream with replay of missed messages.
|
||||
|
||||
This is fully stateless - uses Redis Stream for replay and pub/sub for live updates.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to subscribe to
|
||||
user_id: User ID for ownership validation
|
||||
@@ -219,102 +199,23 @@ async def subscribe_to_task(
|
||||
An asyncio Queue that will receive stream chunks, or None if task not found
|
||||
or user doesn't have access
|
||||
"""
|
||||
# Check in-memory first
|
||||
task = _active_tasks.get(task_id)
|
||||
|
||||
if task:
|
||||
# Validate ownership
|
||||
if user_id and task.user_id and task.user_id != user_id:
|
||||
logger.warning(
|
||||
f"User {user_id} attempted to subscribe to task {task_id} "
|
||||
f"owned by {task.user_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Create a new queue for this subscriber
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
|
||||
|
||||
# Replay from Redis Stream
|
||||
redis = await get_redis_async()
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
|
||||
# Track the last message ID we've seen for gap detection
|
||||
replay_last_id = last_message_id
|
||||
|
||||
# Read all messages from stream starting after last_message_id
|
||||
# xread returns messages with ID > last_message_id
|
||||
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
||||
|
||||
if messages:
|
||||
# messages format: [[stream_name, [(id, {data: json}), ...]]]
|
||||
for _stream_name, stream_messages in messages:
|
||||
for msg_id, msg_data in stream_messages:
|
||||
# Track the last message ID we've processed
|
||||
replay_last_id = (
|
||||
msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||
)
|
||||
if b"data" in msg_data:
|
||||
try:
|
||||
chunk_data = orjson.loads(msg_data[b"data"])
|
||||
# Reconstruct the appropriate response type
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
await subscriber_queue.put(chunk)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to replay message: {e}")
|
||||
|
||||
# Atomically check status and register subscriber under lock
|
||||
# This prevents race condition where task completes between check and subscribe
|
||||
should_start_pubsub = False
|
||||
async with task.lock:
|
||||
if task.status == "running":
|
||||
# Register this subscriber for live updates
|
||||
task.subscribers.add(subscriber_queue)
|
||||
# Start pub/sub listener if this is the first subscriber
|
||||
should_start_pubsub = len(task.subscribers) == 1
|
||||
logger.debug(
|
||||
f"Registered subscriber for task {task_id}, "
|
||||
f"total subscribers: {len(task.subscribers)}"
|
||||
)
|
||||
else:
|
||||
# Task is done, add finish marker
|
||||
await subscriber_queue.put(StreamFinish())
|
||||
|
||||
# After registering, do a second read to catch any messages published
|
||||
# between the first read and registration (closes the race window)
|
||||
if task.status == "running":
|
||||
gap_messages = await redis.xread(
|
||||
{stream_key: replay_last_id}, block=0, count=1000
|
||||
)
|
||||
if gap_messages:
|
||||
for _stream_name, stream_messages in gap_messages:
|
||||
for _msg_id, msg_data in stream_messages:
|
||||
if b"data" in msg_data:
|
||||
try:
|
||||
chunk_data = orjson.loads(msg_data[b"data"])
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
await subscriber_queue.put(chunk)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to replay gap message: {e}")
|
||||
|
||||
# Start pub/sub listener outside the lock to avoid deadlocks
|
||||
if should_start_pubsub:
|
||||
await start_pubsub_listener(task_id)
|
||||
|
||||
return subscriber_queue
|
||||
|
||||
# Try to load from Redis if not in memory
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
|
||||
if not meta:
|
||||
logger.warning(f"Task {task_id} not found in memory or Redis")
|
||||
logger.warning(f"[SSE-RECONNECT] Task {task_id} not found in Redis")
|
||||
return None
|
||||
|
||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||
task_status = meta.get("status", "")
|
||||
task_user_id = meta.get("user_id", "") or None
|
||||
|
||||
logger.info(
|
||||
f"[SSE-RECONNECT] Subscribing to task {task_id}: status={task_status}"
|
||||
)
|
||||
|
||||
# Validate ownership
|
||||
task_user_id = meta.get(b"user_id", b"").decode() or None
|
||||
if user_id and task_user_id and task_user_id != user_id:
|
||||
logger.warning(
|
||||
f"User {user_id} attempted to subscribe to task {task_id} "
|
||||
@@ -322,79 +223,158 @@ async def subscribe_to_task(
|
||||
)
|
||||
return None
|
||||
|
||||
# Replay from Redis Stream only (task is not in memory, so it's completed/crashed)
|
||||
subscriber_queue = asyncio.Queue()
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
|
||||
# Read all messages starting after last_message_id
|
||||
# Step 1: Replay messages from Redis Stream
|
||||
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
||||
|
||||
replayed_count = 0
|
||||
replay_last_id = last_message_id
|
||||
if messages:
|
||||
for _stream_name, stream_messages in messages:
|
||||
for _msg_id, msg_data in stream_messages:
|
||||
if b"data" in msg_data:
|
||||
for msg_id, msg_data in stream_messages:
|
||||
replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||
if "data" in msg_data:
|
||||
try:
|
||||
chunk_data = orjson.loads(msg_data[b"data"])
|
||||
chunk_data = orjson.loads(msg_data["data"])
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
await subscriber_queue.put(chunk)
|
||||
replayed_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to replay message: {e}")
|
||||
|
||||
# Add finish marker since task is not active
|
||||
await subscriber_queue.put(StreamFinish())
|
||||
logger.info(
|
||||
f"[SSE-RECONNECT] Task {task_id}: replayed {replayed_count} messages "
|
||||
f"(last_id={replay_last_id})"
|
||||
)
|
||||
|
||||
# Step 2: If task is still running, start stream listener for live updates
|
||||
if task_status == "running":
|
||||
logger.info(
|
||||
f"[SSE-RECONNECT] Task {task_id} is running, starting stream listener"
|
||||
)
|
||||
asyncio.create_task(
|
||||
_stream_listener(task_id, subscriber_queue, replay_last_id)
|
||||
)
|
||||
else:
|
||||
# Task is completed/failed - add finish marker
|
||||
logger.info(
|
||||
f"[SSE-RECONNECT] Task {task_id} is {task_status}, adding finish marker"
|
||||
)
|
||||
await subscriber_queue.put(StreamFinish())
|
||||
|
||||
return subscriber_queue
|
||||
|
||||
|
||||
async def _stream_listener(
    task_id: str,
    subscriber_queue: asyncio.Queue[StreamBaseResponse],
    last_replayed_id: str,
) -> None:
    """Listen to Redis Stream for new messages using blocking XREAD.

    This approach avoids the duplicate message issue that can occur with pub/sub
    when messages are published during the gap between replay and subscription.

    The listener stops when it forwards a ``StreamFinish`` chunk, when a read
    times out and the task's ``status`` metadata is anything other than
    ``"running"`` (including missing metadata), when it is cancelled, or when
    an unexpected error occurs. In the timeout and error cases a
    ``StreamFinish`` is queued so the subscriber is not left waiting.

    Args:
        task_id: Task ID to listen for
        subscriber_queue: Queue to deliver messages to
        last_replayed_id: Last message ID from replay (continue from here)
    """
    try:
        redis = await get_redis_async()
        stream_key = _get_task_stream_key(task_id)
        meta_key = _get_task_meta_key(task_id)
        current_id = last_replayed_id

        logger.debug(
            f"[SSE-RECONNECT] Stream listener started for task {task_id}, "
            f"from ID {current_id}"
        )

        while True:
            # Block for up to 30 seconds waiting for new messages.
            # The timeout gives us a periodic chance to check the task status.
            messages = await redis.xread(
                {stream_key: current_id}, block=30000, count=100
            )

            if not messages:
                # Timeout - check if task is still running
                status = await redis.hget(meta_key, "status")  # type: ignore[misc]
                if isinstance(status, bytes):
                    status = status.decode()
                # A missing status (metadata deleted/expired) is treated like a
                # finished task; otherwise this loop would poll forever and the
                # subscriber would never receive a finish marker.
                if status != "running":
                    logger.info(
                        f"[SSE-RECONNECT] Task {task_id} no longer running "
                        f"(status={status}), stopping listener"
                    )
                    subscriber_queue.put_nowait(StreamFinish())
                    break
                continue

            for _stream_name, stream_messages in messages:
                for msg_id, msg_data in stream_messages:
                    # Advance the cursor even for entries we cannot deliver so
                    # the next XREAD does not return them again.
                    current_id = msg_id if isinstance(msg_id, str) else msg_id.decode()

                    if "data" not in msg_data:
                        continue

                    try:
                        chunk_data = orjson.loads(msg_data["data"])
                        chunk = _reconstruct_chunk(chunk_data)
                        if chunk:
                            try:
                                subscriber_queue.put_nowait(chunk)
                            except asyncio.QueueFull:
                                logger.warning(
                                    f"Subscriber queue full for task {task_id}"
                                )

                            # Stop listening on finish
                            if isinstance(chunk, StreamFinish):
                                logger.info(
                                    f"[SSE-RECONNECT] Task {task_id} finished "
                                    "via stream"
                                )
                                return
                    except Exception as e:
                        # A single malformed entry must not kill the listener.
                        logger.warning(f"Error processing stream message: {e}")

    except asyncio.CancelledError:
        logger.debug(f"[SSE-RECONNECT] Stream listener cancelled for task {task_id}")
    except Exception as e:
        logger.error(f"Stream listener error for task {task_id}: {e}")
        # On error, send finish to unblock subscriber
        try:
            subscriber_queue.put_nowait(StreamFinish())
        except asyncio.QueueFull:
            pass
|
||||
|
||||
|
||||
async def mark_task_completed(
    task_id: str,
    status: Literal["completed", "failed"] = "completed",
) -> None:
    """Mark a task as completed and publish finish event.

    Publishes a ``StreamFinish`` chunk for the task, writes the final status
    into the task's Redis metadata hash, and drops the local asyncio task
    reference for ``task_id`` if one is tracked.

    Args:
        task_id: Task ID to mark as completed
        status: Final status ("completed" or "failed")
    """
    # Publish finish event (goes to Redis Stream + pub/sub)
    await publish_chunk(task_id, StreamFinish())

    # Update Redis metadata
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)
    await redis.hset(meta_key, "status", status)  # type: ignore[misc]

    # Clean up local task reference if exists
    _local_tasks.pop(task_id, None)

    logger.info(f"[SSE-RECONNECT] Marked task {task_id} as {status}")
|
||||
|
||||
|
||||
async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
|
||||
@@ -408,43 +388,26 @@ async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
|
||||
Returns:
|
||||
ActiveTask if found, None otherwise
|
||||
"""
|
||||
# Check in-memory first
|
||||
for task in _active_tasks.values():
|
||||
if task.operation_id == operation_id:
|
||||
return task
|
||||
|
||||
# Try Redis lookup
|
||||
redis = await get_redis_async()
|
||||
op_key = _get_operation_mapping_key(operation_id)
|
||||
task_id = await redis.get(op_key)
|
||||
|
||||
if task_id:
|
||||
task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
|
||||
# Check if task is in memory
|
||||
if task_id_str in _active_tasks:
|
||||
return _active_tasks[task_id_str]
|
||||
logger.info(
|
||||
f"[SSE-RECONNECT] find_task_by_operation_id: "
|
||||
f"op_key={op_key}, task_id_from_redis={task_id!r}"
|
||||
)
|
||||
|
||||
# Load metadata from Redis
|
||||
meta_key = _get_task_meta_key(task_id_str)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
if not task_id:
|
||||
logger.info(f"[SSE-RECONNECT] No task_id found for operation {operation_id}")
|
||||
return None
|
||||
|
||||
if meta:
|
||||
# Reconstruct task object (not fully active, but has metadata)
|
||||
return ActiveTask(
|
||||
task_id=meta.get(b"task_id", b"").decode(),
|
||||
session_id=meta.get(b"session_id", b"").decode(),
|
||||
user_id=meta.get(b"user_id", b"").decode() or None,
|
||||
tool_call_id=meta.get(b"tool_call_id", b"").decode(),
|
||||
tool_name=meta.get(b"tool_name", b"").decode(),
|
||||
operation_id=operation_id,
|
||||
status=meta.get(b"status", b"running").decode(), # type: ignore
|
||||
)
|
||||
|
||||
return None
|
||||
task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
|
||||
logger.info(f"[SSE-RECONNECT] Looking up task by task_id={task_id_str}")
|
||||
return await get_task(task_id_str)
|
||||
|
||||
|
||||
async def get_task(task_id: str) -> ActiveTask | None:
|
||||
"""Get a task by its ID.
|
||||
"""Get a task by its ID from Redis.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to look up
|
||||
@@ -452,27 +415,127 @@ async def get_task(task_id: str) -> ActiveTask | None:
|
||||
Returns:
|
||||
ActiveTask if found, None otherwise
|
||||
"""
|
||||
# Check in-memory first
|
||||
if task_id in _active_tasks:
|
||||
return _active_tasks[task_id]
|
||||
|
||||
# Try Redis lookup
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
|
||||
if meta:
|
||||
return ActiveTask(
|
||||
task_id=meta.get(b"task_id", b"").decode(),
|
||||
session_id=meta.get(b"session_id", b"").decode(),
|
||||
user_id=meta.get(b"user_id", b"").decode() or None,
|
||||
tool_call_id=meta.get(b"tool_call_id", b"").decode(),
|
||||
tool_name=meta.get(b"tool_name", b"").decode(),
|
||||
operation_id=meta.get(b"operation_id", b"").decode(),
|
||||
status=meta.get(b"status", b"running").decode(), # type: ignore[arg-type]
|
||||
logger.info(
|
||||
f"[SSE-RECONNECT] get_task: meta_key={meta_key}, "
|
||||
f"meta_keys={list(meta.keys()) if meta else 'empty'}, "
|
||||
f"meta={meta}"
|
||||
)
|
||||
|
||||
if not meta:
|
||||
logger.info(f"[SSE-RECONNECT] No metadata found for task {task_id}")
|
||||
return None
|
||||
|
||||
# Note: Redis client uses decode_responses=True, so keys/values are strings
|
||||
task = ActiveTask(
|
||||
task_id=meta.get("task_id", ""),
|
||||
session_id=meta.get("session_id", ""),
|
||||
user_id=meta.get("user_id", "") or None,
|
||||
tool_call_id=meta.get("tool_call_id", ""),
|
||||
tool_name=meta.get("tool_name", ""),
|
||||
operation_id=meta.get("operation_id", ""),
|
||||
status=meta.get("status", "running"), # type: ignore[arg-type]
|
||||
)
|
||||
logger.info(
|
||||
f"[SSE-RECONNECT] get_task returning: task_id={task.task_id}, "
|
||||
f"session_id={task.session_id}, operation_id={task.operation_id}"
|
||||
)
|
||||
return task
|
||||
|
||||
|
||||
async def get_active_task_for_session(
    session_id: str,
    user_id: str | None = None,
) -> tuple[ActiveTask | None, str]:
    """Get the active (running) task for a session, if any.

    Scans Redis for task metadata hashes (keys matching ``TASK_META_PREFIX``)
    and returns the first one whose ``session_id`` matches and whose status is
    ``"running"``. Every scanned key costs one HGETALL, so the lookup is linear
    in the number of task metadata keys.

    Args:
        session_id: Session ID to look up
        user_id: User ID for ownership validation (optional). A task is skipped
            only when both this and the task's stored user ID are set and differ.

    Returns:
        Tuple of (ActiveTask if found and running, last_message_id from Redis
        Stream). When nothing matches the result is ``(None, "0-0")``. The
        last message ID falls back to ``"0-0"`` when the stream is empty or
        cannot be read.
    """
    logger.info(f"[SSE-RECONNECT] Looking for active task for session {session_id}")

    redis = await get_redis_async()

    # Scan Redis for task metadata keys
    cursor = 0
    tasks_checked = 0

    while True:
        cursor, keys = await redis.scan(
            cursor, match=f"{TASK_META_PREFIX}*", count=100
        )

        for key in keys:
            tasks_checked += 1
            meta: dict[Any, Any] = await redis.hgetall(key)  # type: ignore[misc]
            if not meta:
                continue

            # Note: Redis client uses decode_responses=True, so keys/values are strings
            task_session_id = meta.get("session_id", "")
            task_status = meta.get("status", "")
            task_user_id = meta.get("user_id", "") or None
            task_id = meta.get("task_id", "")

            # Log tasks found for this session
            if task_session_id == session_id:
                logger.info(
                    f"[SSE-RECONNECT] Found task for session: "
                    f"task_id={task_id}, status={task_status}"
                )

            if task_session_id == session_id and task_status == "running":
                # Validate ownership
                if user_id and task_user_id and task_user_id != user_id:
                    logger.info(f"[SSE-RECONNECT] Task {task_id} ownership mismatch")
                    continue

                # Get the last message ID from Redis Stream
                stream_key = _get_task_stream_key(task_id)
                last_id = "0-0"
                try:
                    messages = await redis.xrevrange(stream_key, count=1)
                    if messages:
                        msg_id = messages[0][0]
                        last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
                except Exception as e:
                    # Best effort: fall back to the "0-0" default set above.
                    logger.warning(f"Failed to get last message ID: {e}")

                logger.info(
                    f"[SSE-RECONNECT] Found active task: task_id={task_id}, "
                    f"last_message_id={last_id}"
                )

                return (
                    ActiveTask(
                        task_id=task_id,
                        session_id=task_session_id,
                        user_id=task_user_id,
                        tool_call_id=meta.get("tool_call_id", ""),
                        tool_name=meta.get("tool_name", ""),
                        operation_id=meta.get("operation_id", ""),
                        status="running",
                    ),
                    last_id,
                )

        # SCAN signals the end of the iteration by returning cursor 0.
        if cursor == 0:
            break

    logger.info(
        f"[SSE-RECONNECT] No active task found for session {session_id} "
        f"(checked {tasks_checked} tasks)"
    )
    return None, "0-0"
|
||||
|
||||
|
||||
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
||||
@@ -533,116 +596,30 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
||||
|
||||
|
||||
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
    """Track the asyncio.Task for a task (local reference only).

    This is just for cleanup purposes - the task state is in Redis.
    An existing entry for the same ``task_id`` is overwritten.

    Args:
        task_id: Task ID
        asyncio_task: The asyncio Task to track
    """
    _local_tasks[task_id] = asyncio_task
|
||||
|
||||
|
||||
async def unsubscribe_from_task(
|
||||
task_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
) -> None:
|
||||
"""Unsubscribe a queue from a task's stream.
|
||||
"""Clean up when a subscriber disconnects.
|
||||
|
||||
Should be called when a client disconnects to clean up resources.
|
||||
Also stops the pub/sub listener if there are no more local subscribers.
|
||||
With Redis-based pub/sub, there's no explicit unsubscription needed.
|
||||
The pub/sub listener task will be garbage collected when the subscriber
|
||||
stops reading from the queue.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to unsubscribe from
|
||||
subscriber_queue: The queue to remove from subscribers
|
||||
task_id: Task ID
|
||||
subscriber_queue: The subscriber's queue (unused, kept for API compat)
|
||||
"""
|
||||
task = _active_tasks.get(task_id)
|
||||
if task:
|
||||
async with task.lock:
|
||||
task.subscribers.discard(subscriber_queue)
|
||||
remaining = len(task.subscribers)
|
||||
logger.debug(
|
||||
f"Unsubscribed from task {task_id}, "
|
||||
f"remaining subscribers: {remaining}"
|
||||
)
|
||||
# Stop pub/sub listener if no more local subscribers
|
||||
if remaining == 0:
|
||||
await stop_pubsub_listener(task_id)
|
||||
|
||||
|
||||
async def start_pubsub_listener(task_id: str) -> None:
|
||||
"""Start listening to Redis pub/sub for cross-pod delivery.
|
||||
|
||||
This enables real-time updates when another pod publishes chunks for a task
|
||||
that has local subscribers on this pod.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to listen for
|
||||
"""
|
||||
if task_id in _pubsub_listeners:
|
||||
return # Already listening
|
||||
|
||||
task = _active_tasks.get(task_id)
|
||||
if not task:
|
||||
return
|
||||
|
||||
async def _listener():
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
pubsub = redis.pubsub()
|
||||
channel = _get_task_pubsub_channel(task_id)
|
||||
await pubsub.subscribe(channel)
|
||||
logger.debug(f"Started pub/sub listener for task {task_id}")
|
||||
|
||||
async for message in pubsub.listen():
|
||||
if message["type"] != "message":
|
||||
continue
|
||||
|
||||
try:
|
||||
chunk_data = orjson.loads(message["data"])
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
# Deliver to local subscribers
|
||||
local_task = _active_tasks.get(task_id)
|
||||
if local_task:
|
||||
async with local_task.lock:
|
||||
for queue in local_task.subscribers:
|
||||
try:
|
||||
queue.put_nowait(chunk)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
# Stop listening if this was a finish event
|
||||
if isinstance(chunk, StreamFinish):
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing pub/sub message: {e}")
|
||||
|
||||
await pubsub.unsubscribe(channel)
|
||||
await pubsub.close()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Pub/sub listener error for task {task_id}: {e}")
|
||||
finally:
|
||||
_pubsub_listeners.pop(task_id, None)
|
||||
logger.debug(f"Stopped pub/sub listener for task {task_id}")
|
||||
|
||||
listener_task = asyncio.create_task(_listener())
|
||||
_pubsub_listeners[task_id] = listener_task
|
||||
|
||||
|
||||
async def stop_pubsub_listener(task_id: str) -> None:
|
||||
"""Stop the pub/sub listener for a task.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to stop listening for
|
||||
"""
|
||||
listener = _pubsub_listeners.pop(task_id, None)
|
||||
if listener and not listener.done():
|
||||
listener.cancel()
|
||||
try:
|
||||
await listener
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.debug(f"Cancelled pub/sub listener for task {task_id}")
|
||||
# No-op - pub/sub listener cleans up automatically
|
||||
logger.debug(f"[SSE-RECONNECT] Subscriber disconnected from task {task_id}")
|
||||
|
||||
@@ -1079,7 +1079,7 @@
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Get Session",
|
||||
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, or None if not found.",
|
||||
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\nIf there's an active stream for this session, returns the task_id for reconnection.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, including active_stream info if applicable.",
|
||||
"operationId": "getV2GetSession",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
@@ -1214,7 +1214,7 @@
|
||||
"post": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Stream Chat Post",
|
||||
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
|
||||
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to Redis for reconnection support. If the client disconnects,\nthey can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks. First chunk is a \"start\" event\n containing the task_id for reconnection.",
|
||||
"operationId": "postV2StreamChatPost",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
@@ -6313,6 +6313,16 @@
|
||||
"title": "AccuracyTrendsResponse",
|
||||
"description": "Response model for accuracy trends and alerts."
|
||||
},
|
||||
"ActiveStreamInfo": {
|
||||
"properties": {
|
||||
"task_id": { "type": "string", "title": "Task Id" },
|
||||
"last_message_id": { "type": "string", "title": "Last Message Id" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["task_id", "last_message_id"],
|
||||
"title": "ActiveStreamInfo",
|
||||
"description": "Information about an active stream for reconnection."
|
||||
},
|
||||
"AddUserCreditsResponse": {
|
||||
"properties": {
|
||||
"new_balance": { "type": "integer", "title": "New Balance" },
|
||||
@@ -9808,6 +9818,12 @@
|
||||
"items": { "additionalProperties": true, "type": "object" },
|
||||
"type": "array",
|
||||
"title": "Messages"
|
||||
},
|
||||
"active_stream": {
|
||||
"anyOf": [
|
||||
{ "$ref": "#/components/schemas/ActiveStreamInfo" },
|
||||
{ "type": "null" }
|
||||
]
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
|
||||
@@ -25,6 +25,7 @@ export function Chat({
|
||||
const { urlSessionId } = useCopilotSessionId();
|
||||
const hasHandledNotFoundRef = useRef(false);
|
||||
const {
|
||||
session,
|
||||
messages,
|
||||
isLoading,
|
||||
isCreating,
|
||||
@@ -36,6 +37,21 @@ export function Chat({
|
||||
startPollingForOperation,
|
||||
} = useChat({ urlSessionId });
|
||||
|
||||
// Extract active stream info for reconnection
|
||||
const activeStream = (session as { active_stream?: { task_id: string; last_message_id: string } })?.active_stream;
|
||||
|
||||
// Debug logging for SSE reconnection
|
||||
if (session) {
|
||||
console.info("[SSE-RECONNECT] Session loaded:", {
|
||||
sessionId,
|
||||
hasActiveStream: !!activeStream,
|
||||
activeStream: activeStream ? {
|
||||
taskId: activeStream.task_id,
|
||||
lastMessageId: activeStream.last_message_id,
|
||||
} : null,
|
||||
});
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
if (!onSessionNotFound) return;
|
||||
if (!urlSessionId) return;
|
||||
@@ -83,6 +99,10 @@ export function Chat({
|
||||
className="flex-1"
|
||||
onStreamingChange={onStreamingChange}
|
||||
onOperationStarted={startPollingForOperation}
|
||||
activeStream={activeStream ? {
|
||||
taskId: activeStream.task_id,
|
||||
lastMessageId: activeStream.last_message_id,
|
||||
} : undefined}
|
||||
/>
|
||||
)}
|
||||
</main>
|
||||
|
||||
@@ -391,6 +391,12 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
||||
lastMessageId = "0-0", // Redis Stream ID format
|
||||
onChunk,
|
||||
) {
|
||||
console.info("[SSE-RECONNECT] reconnectToTask called:", {
|
||||
sessionId,
|
||||
taskId,
|
||||
lastMessageId,
|
||||
});
|
||||
|
||||
const state = get();
|
||||
const newActiveStreams = new Map(state.activeStreams);
|
||||
let newCompletedStreams = new Map(state.completedStreams);
|
||||
@@ -435,8 +441,14 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
||||
completedStreams: newCompletedStreams,
|
||||
});
|
||||
|
||||
console.info("[SSE-RECONNECT] Starting executeTaskReconnect...");
|
||||
try {
|
||||
await executeTaskReconnect(stream, taskId, lastMessageId);
|
||||
console.info("[SSE-RECONNECT] executeTaskReconnect completed:", {
|
||||
sessionId,
|
||||
taskId,
|
||||
streamStatus: stream.status,
|
||||
});
|
||||
} finally {
|
||||
if (onChunk) stream.onChunkCallbacks.delete(onChunk);
|
||||
if (stream.status !== "streaming") {
|
||||
@@ -468,9 +480,16 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
||||
// Clear active task on completion
|
||||
const taskState = get();
|
||||
const newActiveTasks = new Map(taskState.activeTasks);
|
||||
const hadActiveTask = newActiveTasks.has(sessionId);
|
||||
newActiveTasks.delete(sessionId);
|
||||
set({ activeTasks: newActiveTasks });
|
||||
persistTasks(newActiveTasks);
|
||||
if (hadActiveTask) {
|
||||
console.info(
|
||||
`[ChatStore] Cleared active task for session ${sessionId} ` +
|
||||
`(stream status: ${stream.status})`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ export type StreamStatus = "idle" | "streaming" | "completed" | "error";
|
||||
|
||||
export interface StreamChunk {
|
||||
type:
|
||||
| "stream_start"
|
||||
| "text_chunk"
|
||||
| "text_ended"
|
||||
| "tool_call"
|
||||
@@ -15,6 +16,7 @@ export interface StreamChunk {
|
||||
| "error"
|
||||
| "usage"
|
||||
| "stream_end";
|
||||
taskId?: string; // Task ID for SSE reconnection
|
||||
timestamp?: string;
|
||||
content?: string;
|
||||
message?: string;
|
||||
@@ -41,7 +43,7 @@ export interface StreamChunk {
|
||||
}
|
||||
|
||||
export type VercelStreamChunk =
|
||||
| { type: "start"; messageId: string }
|
||||
| { type: "start"; messageId: string; taskId?: string }
|
||||
| { type: "finish" }
|
||||
| { type: "text-start"; id: string }
|
||||
| { type: "text-delta"; id: string; delta: string }
|
||||
|
||||
@@ -17,6 +17,11 @@ export interface ChatContainerProps {
|
||||
className?: string;
|
||||
onStreamingChange?: (isStreaming: boolean) => void;
|
||||
onOperationStarted?: () => void;
|
||||
/** Active stream info from the server for reconnection */
|
||||
activeStream?: {
|
||||
taskId: string;
|
||||
lastMessageId: string;
|
||||
};
|
||||
}
|
||||
|
||||
export function ChatContainer({
|
||||
@@ -26,6 +31,7 @@ export function ChatContainer({
|
||||
className,
|
||||
onStreamingChange,
|
||||
onOperationStarted,
|
||||
activeStream,
|
||||
}: ChatContainerProps) {
|
||||
const {
|
||||
messages,
|
||||
@@ -41,6 +47,7 @@ export function ChatContainer({
|
||||
initialMessages,
|
||||
initialPrompt,
|
||||
onOperationStarted,
|
||||
activeStream,
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
|
||||
@@ -34,6 +34,22 @@ export function createStreamEventDispatcher(
|
||||
}
|
||||
|
||||
switch (chunk.type) {
|
||||
case "stream_start":
|
||||
// Store task ID for SSE reconnection
|
||||
if (chunk.taskId && deps.onActiveTaskStarted) {
|
||||
console.info("[ChatStream] Stream started with task ID:", {
|
||||
sessionId: deps.sessionId,
|
||||
taskId: chunk.taskId,
|
||||
});
|
||||
deps.onActiveTaskStarted({
|
||||
taskId: chunk.taskId,
|
||||
operationId: chunk.taskId, // Use taskId as operationId for chat streams
|
||||
toolName: "chat",
|
||||
toolCallId: "chat_stream",
|
||||
});
|
||||
}
|
||||
break;
|
||||
|
||||
case "text_chunk":
|
||||
handleTextChunk(chunk, deps);
|
||||
break;
|
||||
@@ -56,7 +72,8 @@ export function createStreamEventDispatcher(
|
||||
break;
|
||||
|
||||
case "stream_end":
|
||||
console.info("[ChatStream] Stream ended:", {
|
||||
// Note: "finish" type from backend gets normalized to "stream_end" by normalizeStreamChunk
|
||||
console.info("[SSE-RECONNECT] Stream ended:", {
|
||||
sessionId: deps.sessionId,
|
||||
hasResponse: deps.hasResponseRef.current,
|
||||
chunkCount: deps.streamingChunksRef.current.length,
|
||||
|
||||
@@ -221,8 +221,10 @@ export function handleStreamEnd(
|
||||
_chunk: StreamChunk,
|
||||
deps: HandlerDependencies,
|
||||
) {
|
||||
console.info("[SSE-RECONNECT] handleStreamEnd called, resetting streaming state");
|
||||
const completedContent = deps.streamingChunksRef.current.join("");
|
||||
if (!completedContent.trim() && !deps.hasResponseRef.current) {
|
||||
console.info("[SSE-RECONNECT] No content received, adding placeholder message");
|
||||
deps.setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
@@ -261,10 +263,14 @@ export function handleStreamEnd(
|
||||
|
||||
export function handleError(chunk: StreamChunk, deps: HandlerDependencies) {
|
||||
const errorMessage = chunk.message || chunk.content || "An error occurred";
|
||||
console.error("Stream error:", errorMessage);
|
||||
console.error("[ChatStream] Stream error:", errorMessage, {
|
||||
sessionId: deps.sessionId,
|
||||
chunk,
|
||||
});
|
||||
if (isRegionBlockedError(chunk)) {
|
||||
deps.setIsRegionBlockedModalOpen(true);
|
||||
}
|
||||
console.info("[ChatStream] Resetting streaming state due to error");
|
||||
deps.setIsStreamingInitiated(false);
|
||||
deps.setHasTextChunks(false);
|
||||
deps.setStreamingChunks([]);
|
||||
|
||||
@@ -41,6 +41,11 @@ interface Args {
|
||||
initialMessages: SessionDetailResponse["messages"];
|
||||
initialPrompt?: string;
|
||||
onOperationStarted?: () => void;
|
||||
/** Active stream info from the server for reconnection */
|
||||
activeStream?: {
|
||||
taskId: string;
|
||||
lastMessageId: string;
|
||||
};
|
||||
}
|
||||
|
||||
export function useChatContainer({
|
||||
@@ -48,6 +53,7 @@ export function useChatContainer({
|
||||
initialMessages,
|
||||
initialPrompt,
|
||||
onOperationStarted,
|
||||
activeStream,
|
||||
}: Args) {
|
||||
const [messages, setMessages] = useState<ChatMessageData[]>([]);
|
||||
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
|
||||
@@ -69,6 +75,8 @@ export function useChatContainer({
|
||||
const getActiveTask = useChatStore((s) => s.getActiveTask);
|
||||
const reconnectToTask = useChatStore((s) => s.reconnectToTask);
|
||||
const isStreaming = isStreamingInitiated || hasTextChunks;
|
||||
// Track whether we've already connected to this activeStream to avoid duplicate connections
|
||||
const connectedActiveStreamRef = useRef<string | null>(null);
|
||||
|
||||
// Callback to store active task info for SSE reconnection
|
||||
function handleActiveTaskStarted(taskInfo: {
|
||||
@@ -88,25 +96,131 @@ export function useChatContainer({
|
||||
|
||||
useEffect(
|
||||
function handleSessionChange() {
|
||||
if (sessionId === previousSessionIdRef.current) return;
|
||||
const isSessionChange = sessionId !== previousSessionIdRef.current;
|
||||
|
||||
const prevSession = previousSessionIdRef.current;
|
||||
if (prevSession) {
|
||||
stopStreaming(prevSession);
|
||||
console.info("[SSE-RECONNECT] handleSessionChange effect running:", {
|
||||
sessionId,
|
||||
previousSessionId: previousSessionIdRef.current,
|
||||
isSessionChange,
|
||||
hasActiveStream: !!activeStream,
|
||||
activeStreamTaskId: activeStream?.taskId,
|
||||
connectedActiveStream: connectedActiveStreamRef.current,
|
||||
});
|
||||
|
||||
// Handle session change - reset state
|
||||
if (isSessionChange) {
|
||||
console.info("[SSE-RECONNECT] Session changed, resetting state");
|
||||
const prevSession = previousSessionIdRef.current;
|
||||
if (prevSession) {
|
||||
stopStreaming(prevSession);
|
||||
}
|
||||
previousSessionIdRef.current = sessionId;
|
||||
connectedActiveStreamRef.current = null; // Reset connected stream tracker
|
||||
setMessages([]);
|
||||
setStreamingChunks([]);
|
||||
streamingChunksRef.current = [];
|
||||
setHasTextChunks(false);
|
||||
setIsStreamingInitiated(false);
|
||||
hasResponseRef.current = false;
|
||||
}
|
||||
previousSessionIdRef.current = sessionId;
|
||||
setMessages([]);
|
||||
setStreamingChunks([]);
|
||||
streamingChunksRef.current = [];
|
||||
setHasTextChunks(false);
|
||||
setIsStreamingInitiated(false);
|
||||
hasResponseRef.current = false;
|
||||
|
||||
if (!sessionId) return;
|
||||
if (!sessionId) {
|
||||
console.info("[SSE-RECONNECT] No sessionId, skipping reconnection check");
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if there's an active task for this session that we should reconnect to
|
||||
// Priority 1: Check if server told us there's an active stream (most authoritative)
|
||||
// Also handles the case where activeStream arrives after initial session load
|
||||
if (activeStream) {
|
||||
// Skip if we've already connected to this exact stream
|
||||
// Check and set immediately to prevent race conditions from effect re-runs
|
||||
const streamKey = `${sessionId}:${activeStream.taskId}`;
|
||||
if (connectedActiveStreamRef.current === streamKey) {
|
||||
console.info(
|
||||
"[SSE-RECONNECT] Already connected to this stream, skipping:",
|
||||
{ streamKey },
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Also skip if there's already an active stream for this session in the store
|
||||
// (handles case where effect re-runs due to activeStreams state change)
|
||||
const existingStream = activeStreams.get(sessionId);
|
||||
if (existingStream && existingStream.status === "streaming") {
|
||||
console.info(
|
||||
"[SSE-RECONNECT] Active stream already exists in store, skipping:",
|
||||
{ sessionId, status: existingStream.status },
|
||||
);
|
||||
connectedActiveStreamRef.current = streamKey;
|
||||
return;
|
||||
}
|
||||
|
||||
// Set immediately after check to prevent race conditions
|
||||
connectedActiveStreamRef.current = streamKey;
|
||||
|
||||
console.info(
|
||||
"[SSE-RECONNECT] Server reports active stream, initiating reconnection:",
|
||||
{
|
||||
sessionId,
|
||||
taskId: activeStream.taskId,
|
||||
lastMessageId: activeStream.lastMessageId,
|
||||
streamKey,
|
||||
},
|
||||
);
|
||||
|
||||
const dispatcher = createStreamEventDispatcher({
|
||||
setHasTextChunks,
|
||||
setStreamingChunks,
|
||||
streamingChunksRef,
|
||||
hasResponseRef,
|
||||
setMessages,
|
||||
setIsRegionBlockedModalOpen,
|
||||
sessionId,
|
||||
setIsStreamingInitiated,
|
||||
onOperationStarted,
|
||||
onActiveTaskStarted: handleActiveTaskStarted,
|
||||
});
|
||||
|
||||
setIsStreamingInitiated(true);
|
||||
// Store this as the active task for future reconnects
|
||||
setActiveTask(sessionId, {
|
||||
taskId: activeStream.taskId,
|
||||
operationId: activeStream.taskId,
|
||||
toolName: "chat",
|
||||
lastMessageId: activeStream.lastMessageId,
|
||||
});
|
||||
// Reconnect to the task stream
|
||||
console.info("[SSE-RECONNECT] Calling reconnectToTask...");
|
||||
reconnectToTask(
|
||||
sessionId,
|
||||
activeStream.taskId,
|
||||
activeStream.lastMessageId,
|
||||
dispatcher,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Only check localStorage/in-memory on session change, not on every render
|
||||
if (!isSessionChange) {
|
||||
console.info(
|
||||
"[SSE-RECONNECT] No active stream and not a session change, skipping fallbacks",
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Priority 2: Check localStorage for active task (client-side state)
|
||||
console.info("[SSE-RECONNECT] Checking localStorage for active task...");
|
||||
const activeTask = getActiveTask(sessionId);
|
||||
if (activeTask) {
|
||||
console.info(
|
||||
"[SSE-RECONNECT] Found active task in localStorage, attempting reconnect:",
|
||||
{
|
||||
sessionId,
|
||||
taskId: activeTask.taskId,
|
||||
lastMessageId: activeTask.lastMessageId,
|
||||
},
|
||||
);
|
||||
|
||||
const dispatcher = createStreamEventDispatcher({
|
||||
setHasTextChunks,
|
||||
setStreamingChunks,
|
||||
@@ -122,6 +236,7 @@ export function useChatContainer({
|
||||
|
||||
setIsStreamingInitiated(true);
|
||||
// Reconnect to the task stream
|
||||
console.info("[SSE-RECONNECT] Calling reconnectToTask from localStorage...");
|
||||
reconnectToTask(
|
||||
sessionId,
|
||||
activeTask.taskId,
|
||||
@@ -129,11 +244,20 @@ export function useChatContainer({
|
||||
dispatcher,
|
||||
);
|
||||
return;
|
||||
} else {
|
||||
console.info("[SSE-RECONNECT] No active task in localStorage");
|
||||
}
|
||||
|
||||
// Otherwise check for an in-memory active stream
|
||||
const activeStream = activeStreams.get(sessionId);
|
||||
if (!activeStream || activeStream.status !== "streaming") return;
|
||||
// Priority 3: Check for an in-memory active stream (same-tab scenario)
|
||||
console.info("[SSE-RECONNECT] Checking in-memory active streams...");
|
||||
const inMemoryStream = activeStreams.get(sessionId);
|
||||
if (!inMemoryStream || inMemoryStream.status !== "streaming") {
|
||||
console.info("[SSE-RECONNECT] No in-memory active stream found:", {
|
||||
hasStream: !!inMemoryStream,
|
||||
status: inMemoryStream?.status,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const dispatcher = createStreamEventDispatcher({
|
||||
setHasTextChunks,
|
||||
@@ -160,6 +284,8 @@ export function useChatContainer({
|
||||
onOperationStarted,
|
||||
getActiveTask,
|
||||
reconnectToTask,
|
||||
activeStream,
|
||||
setActiveTask,
|
||||
],
|
||||
);
|
||||
|
||||
|
||||
@@ -167,8 +167,15 @@ export async function executeTaskReconnect(
|
||||
): Promise<void> {
|
||||
const { abortController } = stream;
|
||||
|
||||
console.info("[SSE-RECONNECT] executeTaskReconnect starting:", {
|
||||
taskId,
|
||||
lastMessageId,
|
||||
retryCount,
|
||||
});
|
||||
|
||||
try {
|
||||
const url = `/api/chat/tasks/${taskId}/stream?last_message_id=${encodeURIComponent(lastMessageId)}`;
|
||||
console.info("[SSE-RECONNECT] Fetching task stream:", { url });
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: "GET",
|
||||
@@ -178,15 +185,33 @@ export async function executeTaskReconnect(
|
||||
signal: abortController.signal,
|
||||
});
|
||||
|
||||
console.info("[SSE-RECONNECT] Task stream response:", {
|
||||
status: response.status,
|
||||
ok: response.ok,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(errorText || `HTTP ${response.status}`);
|
||||
console.error("[SSE-RECONNECT] Task stream error response:", {
|
||||
status: response.status,
|
||||
errorText,
|
||||
});
|
||||
// Don't retry on 404 (task not found) or 403 (access denied) - these are permanent errors
|
||||
const isPermanentError =
|
||||
response.status === 404 || response.status === 403;
|
||||
const error = new Error(errorText || `HTTP ${response.status}`);
|
||||
(error as Error & { status?: number }).status = response.status;
|
||||
(error as Error & { isPermanent?: boolean }).isPermanent =
|
||||
isPermanentError;
|
||||
throw error;
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
throw new Error("Response body is null");
|
||||
}
|
||||
|
||||
console.info("[SSE-RECONNECT] Task stream connected, reading chunks...");
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
@@ -195,6 +220,7 @@ export async function executeTaskReconnect(
|
||||
const { done, value } = await reader.read();
|
||||
|
||||
if (done) {
|
||||
console.info("[SSE-RECONNECT] Task stream reader done (connection closed)");
|
||||
notifySubscribers(stream, { type: "stream_end" });
|
||||
stream.status = "completed";
|
||||
return;
|
||||
@@ -208,6 +234,7 @@ export async function executeTaskReconnect(
|
||||
const data = parseSSELine(line);
|
||||
if (data !== null) {
|
||||
if (data === "[DONE]") {
|
||||
console.info("[SSE-RECONNECT] Task stream received [DONE] signal");
|
||||
notifySubscribers(stream, { type: "stream_end" });
|
||||
stream.status = "completed";
|
||||
return;
|
||||
@@ -220,14 +247,24 @@ export async function executeTaskReconnect(
|
||||
const chunk = normalizeStreamChunk(rawChunk);
|
||||
if (!chunk) continue;
|
||||
|
||||
// Log first few chunks for debugging
|
||||
if (stream.chunks.length < 3) {
|
||||
console.info("[SSE-RECONNECT] Task stream chunk received:", {
|
||||
type: chunk.type,
|
||||
chunkIndex: stream.chunks.length,
|
||||
});
|
||||
}
|
||||
|
||||
notifySubscribers(stream, chunk);
|
||||
|
||||
if (chunk.type === "stream_end") {
|
||||
console.info("[SSE-RECONNECT] Task stream completed via stream_end chunk");
|
||||
stream.status = "completed";
|
||||
return;
|
||||
}
|
||||
|
||||
if (chunk.type === "error") {
|
||||
console.error("[SSE-RECONNECT] Task stream error chunk:", chunk);
|
||||
stream.status = "error";
|
||||
stream.error = new Error(
|
||||
chunk.message || chunk.content || "Stream error",
|
||||
@@ -250,17 +287,35 @@ export async function executeTaskReconnect(
|
||||
return;
|
||||
}
|
||||
|
||||
if (retryCount < MAX_RETRIES) {
|
||||
// Check if this is a permanent error (404/403) that shouldn't be retried
|
||||
const isPermanentError =
|
||||
err instanceof Error &&
|
||||
(err as Error & { isPermanent?: boolean }).isPermanent;
|
||||
|
||||
if (!isPermanentError && retryCount < MAX_RETRIES) {
|
||||
const retryDelay = INITIAL_RETRY_DELAY * Math.pow(2, retryCount);
|
||||
console.log(
|
||||
`[StreamExecutor] Task reconnect retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`,
|
||||
);
|
||||
await new Promise((resolve) => setTimeout(resolve, retryDelay));
|
||||
return executeTaskReconnect(stream, taskId, lastMessageId, retryCount + 1);
|
||||
return executeTaskReconnect(
|
||||
stream,
|
||||
taskId,
|
||||
lastMessageId,
|
||||
retryCount + 1,
|
||||
);
|
||||
}
|
||||
|
||||
// Log permanent errors differently for debugging
|
||||
if (isPermanentError) {
|
||||
console.log(
|
||||
`[StreamExecutor] Task reconnect failed permanently (task not found or access denied): ${(err as Error).message}`,
|
||||
);
|
||||
}
|
||||
|
||||
stream.status = "error";
|
||||
stream.error = err instanceof Error ? err : new Error("Task reconnect failed");
|
||||
stream.error =
|
||||
err instanceof Error ? err : new Error("Task reconnect failed");
|
||||
notifySubscribers(stream, {
|
||||
type: "error",
|
||||
message: stream.error.message,
|
||||
|
||||
@@ -28,7 +28,8 @@ export function normalizeStreamChunk(
|
||||
|
||||
switch (chunk.type) {
|
||||
case "text-delta":
|
||||
return { type: "text_chunk", content: chunk.delta };
|
||||
// Backend sends "content", Vercel AI SDK sends "delta"
|
||||
return { type: "text_chunk", content: chunk.delta || chunk.content };
|
||||
case "text-end":
|
||||
return { type: "text_ended" };
|
||||
case "tool-input-available":
|
||||
@@ -63,6 +64,10 @@ export function normalizeStreamChunk(
|
||||
case "finish":
|
||||
return { type: "stream_end" };
|
||||
case "start":
|
||||
// Start event with optional taskId for reconnection
|
||||
return chunk.taskId
|
||||
? { type: "stream_start", taskId: chunk.taskId }
|
||||
: null;
|
||||
case "text-start":
|
||||
return null;
|
||||
case "tool-input-start":
|
||||
|
||||
Reference in New Issue
Block a user