Fix SSE reconnection

This commit is contained in:
Swifty
2026-02-02 14:08:20 +01:00
parent d1da7fe5da
commit 2cdfe90c56
14 changed files with 870 additions and 417 deletions

View File

@@ -3,12 +3,16 @@
This module provides a consumer that listens for completion notifications
from external services (like Agent Generator) and triggers the appropriate
stream registry and chat service updates.
The consumer initializes its own Prisma client to avoid async context issues.
"""
import asyncio
import logging
import os
import orjson
from prisma import Prisma
from pydantic import BaseModel
from backend.data.rabbitmq import (
@@ -57,12 +61,17 @@ class OperationCompleteMessage(BaseModel):
class ChatCompletionConsumer:
"""Consumer for chat operation completion messages from RabbitMQ."""
"""Consumer for chat operation completion messages from RabbitMQ.
This consumer initializes its own Prisma client in start() to ensure
database operations work correctly within this async context.
"""
def __init__(self):
self._rabbitmq: AsyncRabbitMQ | None = None
self._consumer_task: asyncio.Task | None = None
self._running = False
self._prisma: Prisma | None = None
async def start(self) -> None:
"""Start the completion consumer."""
@@ -70,6 +79,9 @@ class ChatCompletionConsumer:
logger.warning("Completion consumer already running")
return
# Don't initialize Prisma here - do it lazily on first message
# to ensure it's in the same async context as the message handler
self._rabbitmq = AsyncRabbitMQ(RABBITMQ_CONFIG)
await self._rabbitmq.connect()
@@ -77,6 +89,15 @@ class ChatCompletionConsumer:
self._consumer_task = asyncio.create_task(self._consume_messages())
logger.info("Chat completion consumer started")
async def _ensure_prisma(self) -> Prisma:
"""Lazily initialize Prisma client on first use."""
if self._prisma is None:
database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
self._prisma = Prisma(datasource={"url": database_url})
await self._prisma.connect()
logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
return self._prisma
async def stop(self) -> None:
"""Stop the completion consumer."""
self._running = False
@@ -93,6 +114,11 @@ class ChatCompletionConsumer:
await self._rabbitmq.disconnect()
self._rabbitmq = None
if self._prisma:
await self._prisma.disconnect()
self._prisma = None
logger.info("[COMPLETION] Consumer Prisma client disconnected")
logger.info("Chat completion consumer stopped")
async def _consume_messages(self) -> None:
@@ -144,7 +170,7 @@ class ChatCompletionConsumer:
return
async def _handle_message(self, body: bytes) -> None:
"""Handle a single completion message."""
"""Handle a completion message using our own Prisma client."""
try:
data = orjson.loads(body)
message = OperationCompleteMessage(**data)
@@ -153,23 +179,36 @@ class ChatCompletionConsumer:
return
logger.info(
f"Received completion for operation {message.operation_id} "
f"[COMPLETION] Received completion for operation {message.operation_id} "
f"(task_id={message.task_id}, success={message.success})"
)
# Find task in registry
task = await stream_registry.find_task_by_operation_id(message.operation_id)
if task is None:
# Try to look up by task_id directly
task = await stream_registry.get_task(message.task_id)
if task is None:
logger.warning(
f"Task not found for operation {message.operation_id} "
f"[COMPLETION] Task not found for operation {message.operation_id} "
f"(task_id={message.task_id})"
)
return
logger.info(
f"[COMPLETION] Found task: task_id={task.task_id}, "
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
)
# Guard against empty task fields
if not task.task_id or not task.session_id or not task.tool_call_id:
logger.error(
f"[COMPLETION] Task has empty critical fields! "
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
f"tool_call_id={task.tool_call_id!r}"
)
return
if message.success:
await self._handle_success(task, message)
else:
@@ -197,7 +236,7 @@ class ChatCompletionConsumer:
),
)
# Update pending operation in database
# Update pending operation in database using our Prisma client
result_str = (
message.result
if isinstance(message.result, str)
@@ -207,26 +246,45 @@ class ChatCompletionConsumer:
else '{"status": "completed"}'
)
)
await chat_service._update_pending_operation(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
result=result_str,
)
try:
prisma = await self._ensure_prisma()
await prisma.chatmessage.update_many(
where={
"sessionId": task.session_id,
"toolCallId": task.tool_call_id,
},
data={"content": result_str},
)
logger.info(
f"[COMPLETION] Updated tool message for session {task.session_id}"
)
except Exception as e:
logger.error(
f"[COMPLETION] Failed to update tool message: {e}", exc_info=True
)
# Generate LLM continuation with streaming
await chat_service._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
try:
await chat_service._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
except Exception as e:
logger.error(
f"[COMPLETION] Failed to generate LLM continuation: {e}",
exc_info=True,
)
# Mark task as completed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="completed")
await chat_service._mark_operation_completed(task.tool_call_id)
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(
f"Successfully processed completion for task {task.task_id} "
f"(operation {message.operation_id})"
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
)
async def _handle_failure(
@@ -237,31 +295,44 @@ class ChatCompletionConsumer:
"""Handle failed operation completion."""
error_msg = message.error or "Operation failed"
# Publish error to stream registry followed by finish event
# Publish error to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText=error_msg),
)
await stream_registry.publish_chunk(task.task_id, StreamFinish())
# Update pending operation with error
# Update pending operation with error using our Prisma client
error_response = ErrorResponse(
message=error_msg,
error=message.error,
)
await chat_service._update_pending_operation(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
result=error_response.model_dump_json(),
)
try:
prisma = await self._ensure_prisma()
await prisma.chatmessage.update_many(
where={
"sessionId": task.session_id,
"toolCallId": task.tool_call_id,
},
data={"content": error_response.model_dump_json()},
)
logger.info(
f"[COMPLETION] Updated tool message with error for session {task.session_id}"
)
except Exception as e:
logger.error(
f"[COMPLETION] Failed to update tool message: {e}", exc_info=True
)
# Mark task as failed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="failed")
await chat_service._mark_operation_completed(task.tool_call_id)
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(
f"Processed failure for task {task.task_id} "
f"(operation {message.operation_id}): {error_msg}"
f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}"
)
@@ -294,9 +365,6 @@ async def publish_operation_complete(
) -> None:
"""Publish an operation completion message.
This is a helper function for testing or for services that want to
publish completion messages directly.
Args:
operation_id: The operation ID that completed.
task_id: The task ID associated with the operation.

View File

@@ -52,6 +52,10 @@ class StreamStart(StreamBaseResponse):
type: ResponseType = ResponseType.START
messageId: str = Field(..., description="Unique message ID")
taskId: str | None = Field(
default=None,
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
)
class StreamFinish(StreamBaseResponse):

View File

@@ -1,6 +1,7 @@
"""Chat API routes for chat session management and streaming via SSE."""
import logging
import uuid as uuid_module
from collections.abc import AsyncGenerator
from typing import Annotated
@@ -16,7 +17,7 @@ from . import service as chat_service
from . import stream_registry
from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .response_model import StreamFinish, StreamHeartbeat
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
config = ChatConfig()
@@ -58,6 +59,13 @@ class CreateSessionResponse(BaseModel):
user_id: str | None
class ActiveStreamInfo(BaseModel):
"""Information about an active stream for reconnection."""
task_id: str
last_message_id: str # Redis Stream message ID for resumption
class SessionDetailResponse(BaseModel):
"""Response model providing complete details for a chat session, including messages."""
@@ -66,6 +74,7 @@ class SessionDetailResponse(BaseModel):
updated_at: str
user_id: str | None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
class SessionSummaryResponse(BaseModel):
@@ -177,13 +186,14 @@ async def get_session(
Retrieve the details of a specific chat session.
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
If there's an active stream for this session, returns the task_id for reconnection.
Args:
session_id: The unique identifier for the desired chat session.
user_id: The optional authenticated user ID, or None for anonymous access.
Returns:
SessionDetailResponse: Details for the requested session, or None if not found.
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
"""
session = await get_chat_session(session_id, user_id)
@@ -191,10 +201,31 @@ async def get_session(
raise NotFoundError(f"Session {session_id} not found.")
messages = [message.model_dump() for message in session.messages]
# Check if there's an active stream for this session
active_stream_info = None
logger.info(f"[SSE-RECONNECT] Checking for active stream in session {session_id}")
active_task, last_message_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
if active_task:
active_stream_info = ActiveStreamInfo(
task_id=active_task.task_id,
last_message_id=last_message_id,
)
logger.info(
f"[SSE-RECONNECT] Session {session_id} HAS active stream: "
f"task_id={active_task.task_id}, status={active_task.status}, "
f"last_message_id={last_message_id}"
)
else:
logger.info(f"[SSE-RECONNECT] Session {session_id} has NO active stream")
logger.info(
f"Returning session {session_id}: "
f"message_count={len(messages)}, "
f"roles={[m.get('role') for m in messages]}"
f"roles={[m.get('role') for m in messages]}, "
f"has_active_stream={active_stream_info is not None}"
)
return SessionDetailResponse(
@@ -203,6 +234,7 @@ async def get_session(
updated_at=session.updated_at.isoformat(),
user_id=session.user_id or None,
messages=messages,
active_stream=active_stream_info,
)
@@ -222,49 +254,136 @@ async def stream_chat_post(
- Tool call UI elements (if invoked)
- Tool execution results
The AI generation runs in a background task that continues even if the client disconnects.
All chunks are written to Redis for reconnection support. If the client disconnects,
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
Args:
session_id: The chat session identifier to associate with the streamed messages.
request: Request body containing message, is_user_message, and optional context.
user_id: Optional authenticated user ID.
Returns:
StreamingResponse: SSE-formatted response chunks.
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
containing the task_id for reconnection.
"""
import asyncio
session = await _validate_and_get_session(session_id, user_id)
async def event_generator() -> AsyncGenerator[str, None]:
# Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4())
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id="chat_stream", # Not a tool call, but needed for the model
tool_name="chat",
operation_id=operation_id,
)
logger.info(
f"[SSE-RECONNECT] Created stream task for reconnection support: "
f"task_id={task_id}, session_id={session_id}"
)
# Background task that runs the AI generation independently of SSE connection
async def run_ai_generation():
chunk_count = 0
first_chunk_type: str | None = None
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
# AI SDK protocol termination
yield "data: [DONE]\n\n"
try:
# Emit a start event with task_id for reconnection
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
await stream_registry.publish_chunk(task_id, start_chunk)
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
# Write to Redis (subscribers will receive via pub/sub or polling)
await stream_registry.publish_chunk(task_id, chunk)
# Mark task as completed
await stream_registry.mark_task_completed(task_id, "completed")
logger.info(
"[SSE-RECONNECT] Background AI generation completed",
extra={
"session_id": session_id,
"task_id": task_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
except Exception as e:
logger.error(
f"[SSE-RECONNECT] Error in background AI generation for session "
f"{session_id}: {e}"
)
await stream_registry.mark_task_completed(task_id, "failed")
# Start the AI generation in a background task
bg_task = asyncio.create_task(run_ai_generation())
await stream_registry.set_task_asyncio_task(task_id, bg_task)
logger.info(f"[SSE-RECONNECT] Started background AI generation task for {task_id}")
# SSE endpoint that subscribes to the task's stream
async def event_generator() -> AsyncGenerator[str, None]:
try:
# Subscribe to the task stream (this replays existing messages + live updates)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id="0-0", # Get all messages from the beginning
)
if subscriber_queue is None:
logger.error(f"Failed to subscribe to task {task_id}")
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return
# Read from the subscriber queue and yield to SSE
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
logger.info(
f"[SSE-RECONNECT] SSE subscriber received finish for task {task_id}"
)
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except GeneratorExit:
# Client disconnected - that's fine, background task continues
logger.info(
f"[SSE-RECONNECT] SSE client disconnected for task {task_id}, "
f"background generation continues"
)
except Exception as e:
logger.error(f"Error in SSE stream for task {task_id}: {e}")
finally:
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
@@ -409,6 +528,11 @@ async def stream_task(
Raises:
NotFoundError: If task_id is not found or user doesn't have access.
"""
logger.info(
f"[SSE-RECONNECT] Client reconnecting to task stream: "
f"task_id={task_id}, last_message_id={last_message_id}"
)
# Get subscriber queue from stream registry
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
@@ -417,8 +541,15 @@ async def stream_task(
)
if subscriber_queue is None:
logger.warning(
f"[SSE-RECONNECT] Task not found or access denied: task_id={task_id}"
)
raise NotFoundError(f"Task {task_id} not found or access denied.")
logger.info(
f"[SSE-RECONNECT] Successfully subscribed to task stream: task_id={task_id}"
)
async def event_generator() -> AsyncGenerator[str, None]:
import asyncio

View File

@@ -1,12 +1,18 @@
"""Stream registry for managing reconnectable SSE streams.
This module provides a registry for tracking active streaming tasks and their
messages. It supports:
- Creating tasks with unique IDs for long-running operations
- Publishing stream messages to both Redis Streams and in-memory queues
- Subscribing to tasks with replay of missed messages
- Looking up tasks by operation_id for webhook callbacks
- Cross-pod real-time delivery via Redis pub/sub
messages. It uses Redis for all state management (no in-memory state), making
pods stateless and horizontally scalable.
Architecture:
- Redis Stream: Persists all messages for replay
- Redis Pub/Sub: Real-time delivery to subscribers
- Redis Hash: Task metadata (status, session_id, etc.)
Subscribers:
1. Replay missed messages from Redis Stream
2. Subscribe to pub/sub channel for live updates
3. No in-memory state required on the subscribing pod
"""
import asyncio
@@ -25,13 +31,10 @@ from .response_model import StreamBaseResponse, StreamFinish
logger = logging.getLogger(__name__)
config = ChatConfig()
# Track active pub/sub listeners for cross-pod delivery
_pubsub_listeners: dict[str, asyncio.Task] = {}
@dataclass
class ActiveTask:
"""Represents an active streaming task."""
"""Represents an active streaming task (metadata only, no in-memory queues)."""
task_id: str
session_id: str
@@ -41,22 +44,17 @@ class ActiveTask:
operation_id: str
status: Literal["running", "completed", "failed"] = "running"
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
queue: asyncio.Queue[StreamBaseResponse] = field(default_factory=asyncio.Queue)
asyncio_task: asyncio.Task | None = None
# Lock for atomic status checks and subscriber management
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
# Set of subscriber queues for fan-out
subscribers: set[asyncio.Queue[StreamBaseResponse]] = field(default_factory=set)
# Module-level registry for active tasks
_active_tasks: dict[str, ActiveTask] = {}
# Redis key patterns
TASK_META_PREFIX = "chat:task:meta:" # Hash for task metadata
TASK_STREAM_PREFIX = "chat:stream:" # Redis Stream for messages
TASK_OP_PREFIX = "chat:task:op:" # Operation ID -> task_id mapping
TASK_PUBSUB_PREFIX = "chat:task:pubsub:" # Pub/sub channel for cross-pod delivery
TASK_PUBSUB_PREFIX = "chat:task:pubsub:" # Pub/sub channel for real-time delivery
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
_local_tasks: dict[str, asyncio.Task] = {}
def _get_task_meta_key(task_id: str) -> str:
@@ -75,7 +73,7 @@ def _get_operation_mapping_key(operation_id: str) -> str:
def _get_task_pubsub_channel(task_id: str) -> str:
"""Get Redis pub/sub channel for task cross-pod delivery."""
"""Get Redis pub/sub channel for task real-time delivery."""
return f"{TASK_PUBSUB_PREFIX}{task_id}"
@@ -87,7 +85,7 @@ async def create_task(
tool_name: str,
operation_id: str,
) -> ActiveTask:
"""Create a new streaming task in memory and Redis.
"""Create a new streaming task in Redis.
Args:
task_id: Unique identifier for the task
@@ -98,7 +96,7 @@ async def create_task(
operation_id: Operation ID for webhook callbacks
Returns:
The created ActiveTask instance
The created ActiveTask instance (metadata only)
"""
task = ActiveTask(
task_id=task_id,
@@ -109,10 +107,7 @@ async def create_task(
operation_id=operation_id,
)
# Store in memory registry
_active_tasks[task_id] = task
# Store metadata in Redis for durability
# Store metadata in Redis
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
op_key = _get_operation_mapping_key(operation_id)
@@ -136,8 +131,7 @@ async def create_task(
await redis.set(op_key, task_id, ex=config.stream_ttl)
logger.info(
f"Created streaming task {task_id} for operation {operation_id} "
f"in session {session_id}"
f"[SSE-RECONNECT] Created task {task_id} for session {session_id} in Redis"
)
return task
@@ -147,41 +141,26 @@ async def publish_chunk(
task_id: str,
chunk: StreamBaseResponse,
) -> str:
"""Publish a chunk to the task's stream.
"""Publish a chunk to Redis Stream and pub/sub channel.
Delivers to in-memory subscribers first (for real-time), then persists to
Redis Stream (for replay). This order ensures live subscribers get messages
even if Redis temporarily fails.
All delivery is via Redis - no in-memory state.
Args:
task_id: Task ID to publish to
chunk: The stream response chunk to publish
Returns:
The Redis Stream message ID (format: "timestamp-sequence"), or "0-0" if
Redis persistence failed
The Redis Stream message ID
"""
# Deliver to in-memory subscribers FIRST for real-time updates
task = _active_tasks.get(task_id)
if task:
async with task.lock:
for subscriber_queue in task.subscribers:
try:
subscriber_queue.put_nowait(chunk)
except asyncio.QueueFull:
logger.warning(
f"Subscriber queue full for task {task_id}, dropping chunk"
)
# Then persist to Redis Stream for replay (with error handling)
message_id = "0-0"
chunk_json = chunk.model_dump_json()
message_id = "0-0"
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
pubsub_channel = _get_task_pubsub_channel(task_id)
# Add to Redis Stream with auto-generated ID
# The ID format is "timestamp-sequence" which gives us ordering
# Write to Redis Stream for persistence/replay
raw_id = await redis.xadd(
stream_key,
{"data": chunk_json},
@@ -189,14 +168,13 @@ async def publish_chunk(
)
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
# Publish to pub/sub for cross-pod real-time delivery
pubsub_channel = _get_task_pubsub_channel(task_id)
# Publish to pub/sub for real-time delivery
await redis.publish(pubsub_channel, chunk_json)
logger.debug(f"Published chunk to task {task_id}, message_id={message_id}")
except Exception as e:
logger.error(
f"Failed to persist chunk to Redis for task {task_id}: {e}",
f"Failed to publish chunk for task {task_id}: {e}",
exc_info=True,
)
@@ -210,6 +188,8 @@ async def subscribe_to_task(
) -> asyncio.Queue[StreamBaseResponse] | None:
"""Subscribe to a task's stream with replay of missed messages.
This is fully stateless - uses Redis Stream for replay and pub/sub for live updates.
Args:
task_id: Task ID to subscribe to
user_id: User ID for ownership validation
@@ -219,102 +199,23 @@ async def subscribe_to_task(
An asyncio Queue that will receive stream chunks, or None if task not found
or user doesn't have access
"""
# Check in-memory first
task = _active_tasks.get(task_id)
if task:
# Validate ownership
if user_id and task.user_id and task.user_id != user_id:
logger.warning(
f"User {user_id} attempted to subscribe to task {task_id} "
f"owned by {task.user_id}"
)
return None
# Create a new queue for this subscriber
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
# Replay from Redis Stream
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
# Track the last message ID we've seen for gap detection
replay_last_id = last_message_id
# Read all messages from stream starting after last_message_id
# xread returns messages with ID > last_message_id
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
if messages:
# messages format: [[stream_name, [(id, {data: json}), ...]]]
for _stream_name, stream_messages in messages:
for msg_id, msg_data in stream_messages:
# Track the last message ID we've processed
replay_last_id = (
msg_id if isinstance(msg_id, str) else msg_id.decode()
)
if b"data" in msg_data:
try:
chunk_data = orjson.loads(msg_data[b"data"])
# Reconstruct the appropriate response type
chunk = _reconstruct_chunk(chunk_data)
if chunk:
await subscriber_queue.put(chunk)
except Exception as e:
logger.warning(f"Failed to replay message: {e}")
# Atomically check status and register subscriber under lock
# This prevents race condition where task completes between check and subscribe
should_start_pubsub = False
async with task.lock:
if task.status == "running":
# Register this subscriber for live updates
task.subscribers.add(subscriber_queue)
# Start pub/sub listener if this is the first subscriber
should_start_pubsub = len(task.subscribers) == 1
logger.debug(
f"Registered subscriber for task {task_id}, "
f"total subscribers: {len(task.subscribers)}"
)
else:
# Task is done, add finish marker
await subscriber_queue.put(StreamFinish())
# After registering, do a second read to catch any messages published
# between the first read and registration (closes the race window)
if task.status == "running":
gap_messages = await redis.xread(
{stream_key: replay_last_id}, block=0, count=1000
)
if gap_messages:
for _stream_name, stream_messages in gap_messages:
for _msg_id, msg_data in stream_messages:
if b"data" in msg_data:
try:
chunk_data = orjson.loads(msg_data[b"data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
await subscriber_queue.put(chunk)
except Exception as e:
logger.warning(f"Failed to replay gap message: {e}")
# Start pub/sub listener outside the lock to avoid deadlocks
if should_start_pubsub:
await start_pubsub_listener(task_id)
return subscriber_queue
# Try to load from Redis if not in memory
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
logger.warning(f"Task {task_id} not found in memory or Redis")
logger.warning(f"[SSE-RECONNECT] Task {task_id} not found in Redis")
return None
# Note: Redis client uses decode_responses=True, so keys are strings
task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None
logger.info(
f"[SSE-RECONNECT] Subscribing to task {task_id}: status={task_status}"
)
# Validate ownership
task_user_id = meta.get(b"user_id", b"").decode() or None
if user_id and task_user_id and task_user_id != user_id:
logger.warning(
f"User {user_id} attempted to subscribe to task {task_id} "
@@ -322,79 +223,158 @@ async def subscribe_to_task(
)
return None
# Replay from Redis Stream only (task is not in memory, so it's completed/crashed)
subscriber_queue = asyncio.Queue()
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
stream_key = _get_task_stream_key(task_id)
# Read all messages starting after last_message_id
# Step 1: Replay messages from Redis Stream
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
replayed_count = 0
replay_last_id = last_message_id
if messages:
for _stream_name, stream_messages in messages:
for _msg_id, msg_data in stream_messages:
if b"data" in msg_data:
for msg_id, msg_data in stream_messages:
replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
# Note: Redis client uses decode_responses=True, so keys are strings
if "data" in msg_data:
try:
chunk_data = orjson.loads(msg_data[b"data"])
chunk_data = orjson.loads(msg_data["data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
await subscriber_queue.put(chunk)
replayed_count += 1
except Exception as e:
logger.warning(f"Failed to replay message: {e}")
# Add finish marker since task is not active
await subscriber_queue.put(StreamFinish())
logger.info(
f"[SSE-RECONNECT] Task {task_id}: replayed {replayed_count} messages "
f"(last_id={replay_last_id})"
)
# Step 2: If task is still running, start stream listener for live updates
if task_status == "running":
logger.info(
f"[SSE-RECONNECT] Task {task_id} is running, starting stream listener"
)
asyncio.create_task(
_stream_listener(task_id, subscriber_queue, replay_last_id)
)
else:
# Task is completed/failed - add finish marker
logger.info(
f"[SSE-RECONNECT] Task {task_id} is {task_status}, adding finish marker"
)
await subscriber_queue.put(StreamFinish())
return subscriber_queue
async def _stream_listener(
task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
last_replayed_id: str,
) -> None:
"""Listen to Redis Stream for new messages using blocking XREAD.
This approach avoids the duplicate message issue that can occur with pub/sub
when messages are published during the gap between replay and subscription.
Args:
task_id: Task ID to listen for
subscriber_queue: Queue to deliver messages to
last_replayed_id: Last message ID from replay (continue from here)
"""
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
current_id = last_replayed_id
logger.debug(
f"[SSE-RECONNECT] Stream listener started for task {task_id}, "
f"from ID {current_id}"
)
while True:
# Block for up to 30 seconds waiting for new messages
# This allows periodic checking if task is still running
messages = await redis.xread(
{stream_key: current_id}, block=30000, count=100
)
if not messages:
# Timeout - check if task is still running
meta_key = _get_task_meta_key(task_id)
status = await redis.hget(meta_key, "status") # type: ignore[misc]
if status and status != "running":
logger.info(
f"[SSE-RECONNECT] Task {task_id} no longer running "
f"(status={status}), stopping listener"
)
subscriber_queue.put_nowait(StreamFinish())
break
continue
for _stream_name, stream_messages in messages:
for msg_id, msg_data in stream_messages:
current_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
if "data" not in msg_data:
continue
try:
chunk_data = orjson.loads(msg_data["data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
try:
subscriber_queue.put_nowait(chunk)
except asyncio.QueueFull:
logger.warning(
f"Subscriber queue full for task {task_id}"
)
# Stop listening on finish
if isinstance(chunk, StreamFinish):
logger.info(
f"[SSE-RECONNECT] Task {task_id} finished "
"via stream"
)
return
except Exception as e:
logger.warning(f"Error processing stream message: {e}")
except asyncio.CancelledError:
logger.debug(f"[SSE-RECONNECT] Stream listener cancelled for task {task_id}")
except Exception as e:
logger.error(f"Stream listener error for task {task_id}: {e}")
# On error, send finish to unblock subscriber
try:
subscriber_queue.put_nowait(StreamFinish())
except asyncio.QueueFull:
pass
async def mark_task_completed(
    task_id: str,
    status: Literal["completed", "failed"] = "completed",
) -> None:
    """Mark a task as completed and publish finish event.

    Publishes a ``StreamFinish`` chunk (which goes to the Redis Stream for
    replay and to pub/sub for cross-pod delivery), records the final status
    in the task's Redis metadata hash, and drops the local asyncio.Task
    reference if this pod was running the generation.

    Args:
        task_id: Task ID to mark as completed
        status: Final status ("completed" or "failed")
    """
    # Publish finish event (goes to Redis Stream + pub/sub) so that both
    # connected subscribers and late reconnectors observe the terminal chunk.
    await publish_chunk(task_id, StreamFinish())

    # Update Redis metadata so reconnection lookups see the terminal status.
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)
    await redis.hset(meta_key, "status", status)  # type: ignore[misc]

    # Clean up local task reference if it exists (only set on the pod that
    # started the generation).
    _local_tasks.pop(task_id, None)

    logger.info(f"[SSE-RECONNECT] Marked task {task_id} as {status}")
async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
    """Find a task by its operation ID via the Redis operation mapping.

    Looks up the operation->task mapping key in Redis, then delegates to
    :func:`get_task` to load the task metadata.

    Args:
        operation_id: Operation ID to look up

    Returns:
        ActiveTask if found, None otherwise
    """
    redis = await get_redis_async()
    op_key = _get_operation_mapping_key(operation_id)
    task_id = await redis.get(op_key)
    logger.info(
        f"[SSE-RECONNECT] find_task_by_operation_id: "
        f"op_key={op_key}, task_id_from_redis={task_id!r}"
    )
    if not task_id:
        logger.info(f"[SSE-RECONNECT] No task_id found for operation {operation_id}")
        return None
    # Redis may return bytes depending on client configuration; normalize.
    task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
    logger.info(f"[SSE-RECONNECT] Looking up task by task_id={task_id_str}")
    return await get_task(task_id_str)
async def get_task(task_id: str) -> ActiveTask | None:
    """Get a task by its ID from Redis.

    Loads the task metadata hash and reconstructs an :class:`ActiveTask`.
    The returned task carries metadata only (it is not the live in-process
    task object).

    Args:
        task_id: Task ID to look up

    Returns:
        ActiveTask if found, None otherwise
    """
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)
    meta: dict[Any, Any] = await redis.hgetall(meta_key)  # type: ignore[misc]
    logger.info(
        f"[SSE-RECONNECT] get_task: meta_key={meta_key}, "
        f"meta_keys={list(meta.keys()) if meta else 'empty'}, "
        f"meta={meta}"
    )
    if not meta:
        logger.info(f"[SSE-RECONNECT] No metadata found for task {task_id}")
        return None
    # Note: Redis client uses decode_responses=True, so keys/values are strings
    task = ActiveTask(
        task_id=meta.get("task_id", ""),
        session_id=meta.get("session_id", ""),
        user_id=meta.get("user_id", "") or None,
        tool_call_id=meta.get("tool_call_id", ""),
        tool_name=meta.get("tool_name", ""),
        operation_id=meta.get("operation_id", ""),
        status=meta.get("status", "running"),  # type: ignore[arg-type]
    )
    logger.info(
        f"[SSE-RECONNECT] get_task returning: task_id={task.task_id}, "
        f"session_id={task.session_id}, operation_id={task.operation_id}"
    )
    return task
async def get_active_task_for_session(
    session_id: str,
    user_id: str | None = None,
) -> tuple[ActiveTask | None, str]:
    """Get the active (running) task for a session, if any.

    Scans Redis for task metadata keys matching the session_id with
    status="running". Also fetches the last message ID from the task's
    Redis Stream so a reconnecting client can resume from where it left off.

    Args:
        session_id: Session ID to look up
        user_id: User ID for ownership validation (optional)

    Returns:
        Tuple of (ActiveTask if found and running, last_message_id from
        Redis Stream; "0-0" when no active task exists)
    """
    logger.info(f"[SSE-RECONNECT] Looking for active task for session {session_id}")
    redis = await get_redis_async()

    # Scan Redis for task metadata keys. SCAN is cursor-based; iterate until
    # the cursor wraps back to 0.
    cursor = 0
    tasks_checked = 0
    while True:
        cursor, keys = await redis.scan(
            cursor, match=f"{TASK_META_PREFIX}*", count=100
        )
        for key in keys:
            tasks_checked += 1
            meta: dict[Any, Any] = await redis.hgetall(key)  # type: ignore[misc]
            if not meta:
                continue
            # Note: Redis client uses decode_responses=True, so keys/values are strings
            task_session_id = meta.get("session_id", "")
            task_status = meta.get("status", "")
            task_user_id = meta.get("user_id", "") or None
            task_id = meta.get("task_id", "")
            # Log tasks found for this session (any status) to aid debugging.
            if task_session_id == session_id:
                logger.info(
                    f"[SSE-RECONNECT] Found task for session: "
                    f"task_id={task_id}, status={task_status}"
                )
            if task_session_id == session_id and task_status == "running":
                # Validate ownership before handing the task back.
                if user_id and task_user_id and task_user_id != user_id:
                    logger.info(f"[SSE-RECONNECT] Task {task_id} ownership mismatch")
                    continue
                # Get the last message ID from the task's Redis Stream so the
                # client can resume replay after this ID.
                stream_key = _get_task_stream_key(task_id)
                last_id = "0-0"
                try:
                    messages = await redis.xrevrange(stream_key, count=1)
                    if messages:
                        msg_id = messages[0][0]
                        last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
                except Exception as e:
                    logger.warning(f"Failed to get last message ID: {e}")
                logger.info(
                    f"[SSE-RECONNECT] Found active task: task_id={task_id}, "
                    f"last_message_id={last_id}"
                )
                return (
                    ActiveTask(
                        task_id=task_id,
                        session_id=task_session_id,
                        user_id=task_user_id,
                        tool_call_id=meta.get("tool_call_id", ""),
                        tool_name=meta.get("tool_name", ""),
                        operation_id=meta.get("operation_id", ""),
                        status="running",
                    ),
                    last_id,
                )
        if cursor == 0:
            break

    logger.info(
        f"[SSE-RECONNECT] No active task found for session {session_id} "
        f"(checked {tasks_checked} tasks)"
    )
    return None, "0-0"
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
@@ -533,116 +596,30 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
    """Track the asyncio.Task for a task (local reference only).

    This is just for cleanup purposes - the authoritative task state is in
    Redis; only the pod that started the generation holds this reference.

    Args:
        task_id: Task ID
        asyncio_task: The asyncio Task to track
    """
    _local_tasks[task_id] = asyncio_task
async def unsubscribe_from_task(
    task_id: str,
    subscriber_queue: asyncio.Queue[StreamBaseResponse],
) -> None:
    """Clean up when a subscriber disconnects.

    With Redis-based pub/sub, there's no explicit unsubscription needed.
    The pub/sub listener task will be garbage collected when the subscriber
    stops reading from the queue.

    Args:
        task_id: Task ID
        subscriber_queue: The subscriber's queue (unused, kept for API compat)
    """
    # No-op - pub/sub listener cleans up automatically
    logger.debug(f"[SSE-RECONNECT] Subscriber disconnected from task {task_id}")

View File

@@ -1079,7 +1079,7 @@
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Get Session",
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, or None if not found.",
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\nIf there's an active stream for this session, returns the task_id for reconnection.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, including active_stream info if applicable.",
"operationId": "getV2GetSession",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
@@ -1214,7 +1214,7 @@
"post": {
"tags": ["v2", "chat", "chat"],
"summary": "Stream Chat Post",
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to Redis for reconnection support. If the client disconnects,\nthey can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks. First chunk is a \"start\" event\n containing the task_id for reconnection.",
"operationId": "postV2StreamChatPost",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
@@ -6313,6 +6313,16 @@
"title": "AccuracyTrendsResponse",
"description": "Response model for accuracy trends and alerts."
},
"ActiveStreamInfo": {
"properties": {
"task_id": { "type": "string", "title": "Task Id" },
"last_message_id": { "type": "string", "title": "Last Message Id" }
},
"type": "object",
"required": ["task_id", "last_message_id"],
"title": "ActiveStreamInfo",
"description": "Information about an active stream for reconnection."
},
"AddUserCreditsResponse": {
"properties": {
"new_balance": { "type": "integer", "title": "New Balance" },
@@ -9808,6 +9818,12 @@
"items": { "additionalProperties": true, "type": "object" },
"type": "array",
"title": "Messages"
},
"active_stream": {
"anyOf": [
{ "$ref": "#/components/schemas/ActiveStreamInfo" },
{ "type": "null" }
]
}
},
"type": "object",

View File

@@ -25,6 +25,7 @@ export function Chat({
const { urlSessionId } = useCopilotSessionId();
const hasHandledNotFoundRef = useRef(false);
const {
session,
messages,
isLoading,
isCreating,
@@ -36,6 +37,21 @@ export function Chat({
startPollingForOperation,
} = useChat({ urlSessionId });
// Extract active stream info for reconnection
const activeStream = (session as { active_stream?: { task_id: string; last_message_id: string } })?.active_stream;
// Debug logging for SSE reconnection
if (session) {
console.info("[SSE-RECONNECT] Session loaded:", {
sessionId,
hasActiveStream: !!activeStream,
activeStream: activeStream ? {
taskId: activeStream.task_id,
lastMessageId: activeStream.last_message_id,
} : null,
});
}
useEffect(() => {
if (!onSessionNotFound) return;
if (!urlSessionId) return;
@@ -83,6 +99,10 @@ export function Chat({
className="flex-1"
onStreamingChange={onStreamingChange}
onOperationStarted={startPollingForOperation}
activeStream={activeStream ? {
taskId: activeStream.task_id,
lastMessageId: activeStream.last_message_id,
} : undefined}
/>
)}
</main>

View File

@@ -391,6 +391,12 @@ export const useChatStore = create<ChatStore>((set, get) => ({
lastMessageId = "0-0", // Redis Stream ID format
onChunk,
) {
console.info("[SSE-RECONNECT] reconnectToTask called:", {
sessionId,
taskId,
lastMessageId,
});
const state = get();
const newActiveStreams = new Map(state.activeStreams);
let newCompletedStreams = new Map(state.completedStreams);
@@ -435,8 +441,14 @@ export const useChatStore = create<ChatStore>((set, get) => ({
completedStreams: newCompletedStreams,
});
console.info("[SSE-RECONNECT] Starting executeTaskReconnect...");
try {
await executeTaskReconnect(stream, taskId, lastMessageId);
console.info("[SSE-RECONNECT] executeTaskReconnect completed:", {
sessionId,
taskId,
streamStatus: stream.status,
});
} finally {
if (onChunk) stream.onChunkCallbacks.delete(onChunk);
if (stream.status !== "streaming") {
@@ -468,9 +480,16 @@ export const useChatStore = create<ChatStore>((set, get) => ({
// Clear active task on completion
const taskState = get();
const newActiveTasks = new Map(taskState.activeTasks);
const hadActiveTask = newActiveTasks.has(sessionId);
newActiveTasks.delete(sessionId);
set({ activeTasks: newActiveTasks });
persistTasks(newActiveTasks);
if (hadActiveTask) {
console.info(
`[ChatStore] Cleared active task for session ${sessionId} ` +
`(stream status: ${stream.status})`,
);
}
}
}
}

View File

@@ -4,6 +4,7 @@ export type StreamStatus = "idle" | "streaming" | "completed" | "error";
export interface StreamChunk {
type:
| "stream_start"
| "text_chunk"
| "text_ended"
| "tool_call"
@@ -15,6 +16,7 @@ export interface StreamChunk {
| "error"
| "usage"
| "stream_end";
taskId?: string; // Task ID for SSE reconnection
timestamp?: string;
content?: string;
message?: string;
@@ -41,7 +43,7 @@ export interface StreamChunk {
}
export type VercelStreamChunk =
| { type: "start"; messageId: string }
| { type: "start"; messageId: string; taskId?: string }
| { type: "finish" }
| { type: "text-start"; id: string }
| { type: "text-delta"; id: string; delta: string }

View File

@@ -17,6 +17,11 @@ export interface ChatContainerProps {
className?: string;
onStreamingChange?: (isStreaming: boolean) => void;
onOperationStarted?: () => void;
/** Active stream info from the server for reconnection */
activeStream?: {
taskId: string;
lastMessageId: string;
};
}
export function ChatContainer({
@@ -26,6 +31,7 @@ export function ChatContainer({
className,
onStreamingChange,
onOperationStarted,
activeStream,
}: ChatContainerProps) {
const {
messages,
@@ -41,6 +47,7 @@ export function ChatContainer({
initialMessages,
initialPrompt,
onOperationStarted,
activeStream,
});
useEffect(() => {

View File

@@ -34,6 +34,22 @@ export function createStreamEventDispatcher(
}
switch (chunk.type) {
case "stream_start":
// Store task ID for SSE reconnection
if (chunk.taskId && deps.onActiveTaskStarted) {
console.info("[ChatStream] Stream started with task ID:", {
sessionId: deps.sessionId,
taskId: chunk.taskId,
});
deps.onActiveTaskStarted({
taskId: chunk.taskId,
operationId: chunk.taskId, // Use taskId as operationId for chat streams
toolName: "chat",
toolCallId: "chat_stream",
});
}
break;
case "text_chunk":
handleTextChunk(chunk, deps);
break;
@@ -56,7 +72,8 @@ export function createStreamEventDispatcher(
break;
case "stream_end":
console.info("[ChatStream] Stream ended:", {
// Note: "finish" type from backend gets normalized to "stream_end" by normalizeStreamChunk
console.info("[SSE-RECONNECT] Stream ended:", {
sessionId: deps.sessionId,
hasResponse: deps.hasResponseRef.current,
chunkCount: deps.streamingChunksRef.current.length,

View File

@@ -221,8 +221,10 @@ export function handleStreamEnd(
_chunk: StreamChunk,
deps: HandlerDependencies,
) {
console.info("[SSE-RECONNECT] handleStreamEnd called, resetting streaming state");
const completedContent = deps.streamingChunksRef.current.join("");
if (!completedContent.trim() && !deps.hasResponseRef.current) {
console.info("[SSE-RECONNECT] No content received, adding placeholder message");
deps.setMessages((prev) => [
...prev,
{
@@ -261,10 +263,14 @@ export function handleStreamEnd(
export function handleError(chunk: StreamChunk, deps: HandlerDependencies) {
const errorMessage = chunk.message || chunk.content || "An error occurred";
console.error("Stream error:", errorMessage);
console.error("[ChatStream] Stream error:", errorMessage, {
sessionId: deps.sessionId,
chunk,
});
if (isRegionBlockedError(chunk)) {
deps.setIsRegionBlockedModalOpen(true);
}
console.info("[ChatStream] Resetting streaming state due to error");
deps.setIsStreamingInitiated(false);
deps.setHasTextChunks(false);
deps.setStreamingChunks([]);

View File

@@ -41,6 +41,11 @@ interface Args {
initialMessages: SessionDetailResponse["messages"];
initialPrompt?: string;
onOperationStarted?: () => void;
/** Active stream info from the server for reconnection */
activeStream?: {
taskId: string;
lastMessageId: string;
};
}
export function useChatContainer({
@@ -48,6 +53,7 @@ export function useChatContainer({
initialMessages,
initialPrompt,
onOperationStarted,
activeStream,
}: Args) {
const [messages, setMessages] = useState<ChatMessageData[]>([]);
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
@@ -69,6 +75,8 @@ export function useChatContainer({
const getActiveTask = useChatStore((s) => s.getActiveTask);
const reconnectToTask = useChatStore((s) => s.reconnectToTask);
const isStreaming = isStreamingInitiated || hasTextChunks;
// Track whether we've already connected to this activeStream to avoid duplicate connections
const connectedActiveStreamRef = useRef<string | null>(null);
// Callback to store active task info for SSE reconnection
function handleActiveTaskStarted(taskInfo: {
@@ -88,25 +96,131 @@ export function useChatContainer({
useEffect(
function handleSessionChange() {
if (sessionId === previousSessionIdRef.current) return;
const isSessionChange = sessionId !== previousSessionIdRef.current;
const prevSession = previousSessionIdRef.current;
if (prevSession) {
stopStreaming(prevSession);
console.info("[SSE-RECONNECT] handleSessionChange effect running:", {
sessionId,
previousSessionId: previousSessionIdRef.current,
isSessionChange,
hasActiveStream: !!activeStream,
activeStreamTaskId: activeStream?.taskId,
connectedActiveStream: connectedActiveStreamRef.current,
});
// Handle session change - reset state
if (isSessionChange) {
console.info("[SSE-RECONNECT] Session changed, resetting state");
const prevSession = previousSessionIdRef.current;
if (prevSession) {
stopStreaming(prevSession);
}
previousSessionIdRef.current = sessionId;
connectedActiveStreamRef.current = null; // Reset connected stream tracker
setMessages([]);
setStreamingChunks([]);
streamingChunksRef.current = [];
setHasTextChunks(false);
setIsStreamingInitiated(false);
hasResponseRef.current = false;
}
previousSessionIdRef.current = sessionId;
setMessages([]);
setStreamingChunks([]);
streamingChunksRef.current = [];
setHasTextChunks(false);
setIsStreamingInitiated(false);
hasResponseRef.current = false;
if (!sessionId) return;
if (!sessionId) {
console.info("[SSE-RECONNECT] No sessionId, skipping reconnection check");
return;
}
// Check if there's an active task for this session that we should reconnect to
// Priority 1: Check if server told us there's an active stream (most authoritative)
// Also handles the case where activeStream arrives after initial session load
if (activeStream) {
// Skip if we've already connected to this exact stream
// Check and set immediately to prevent race conditions from effect re-runs
const streamKey = `${sessionId}:${activeStream.taskId}`;
if (connectedActiveStreamRef.current === streamKey) {
console.info(
"[SSE-RECONNECT] Already connected to this stream, skipping:",
{ streamKey },
);
return;
}
// Also skip if there's already an active stream for this session in the store
// (handles case where effect re-runs due to activeStreams state change)
const existingStream = activeStreams.get(sessionId);
if (existingStream && existingStream.status === "streaming") {
console.info(
"[SSE-RECONNECT] Active stream already exists in store, skipping:",
{ sessionId, status: existingStream.status },
);
connectedActiveStreamRef.current = streamKey;
return;
}
// Set immediately after check to prevent race conditions
connectedActiveStreamRef.current = streamKey;
console.info(
"[SSE-RECONNECT] Server reports active stream, initiating reconnection:",
{
sessionId,
taskId: activeStream.taskId,
lastMessageId: activeStream.lastMessageId,
streamKey,
},
);
const dispatcher = createStreamEventDispatcher({
setHasTextChunks,
setStreamingChunks,
streamingChunksRef,
hasResponseRef,
setMessages,
setIsRegionBlockedModalOpen,
sessionId,
setIsStreamingInitiated,
onOperationStarted,
onActiveTaskStarted: handleActiveTaskStarted,
});
setIsStreamingInitiated(true);
// Store this as the active task for future reconnects
setActiveTask(sessionId, {
taskId: activeStream.taskId,
operationId: activeStream.taskId,
toolName: "chat",
lastMessageId: activeStream.lastMessageId,
});
// Reconnect to the task stream
console.info("[SSE-RECONNECT] Calling reconnectToTask...");
reconnectToTask(
sessionId,
activeStream.taskId,
activeStream.lastMessageId,
dispatcher,
);
return;
}
// Only check localStorage/in-memory on session change, not on every render
if (!isSessionChange) {
console.info(
"[SSE-RECONNECT] No active stream and not a session change, skipping fallbacks",
);
return;
}
// Priority 2: Check localStorage for active task (client-side state)
console.info("[SSE-RECONNECT] Checking localStorage for active task...");
const activeTask = getActiveTask(sessionId);
if (activeTask) {
console.info(
"[SSE-RECONNECT] Found active task in localStorage, attempting reconnect:",
{
sessionId,
taskId: activeTask.taskId,
lastMessageId: activeTask.lastMessageId,
},
);
const dispatcher = createStreamEventDispatcher({
setHasTextChunks,
setStreamingChunks,
@@ -122,6 +236,7 @@ export function useChatContainer({
setIsStreamingInitiated(true);
// Reconnect to the task stream
console.info("[SSE-RECONNECT] Calling reconnectToTask from localStorage...");
reconnectToTask(
sessionId,
activeTask.taskId,
@@ -129,11 +244,20 @@ export function useChatContainer({
dispatcher,
);
return;
} else {
console.info("[SSE-RECONNECT] No active task in localStorage");
}
// Otherwise check for an in-memory active stream
const activeStream = activeStreams.get(sessionId);
if (!activeStream || activeStream.status !== "streaming") return;
// Priority 3: Check for an in-memory active stream (same-tab scenario)
console.info("[SSE-RECONNECT] Checking in-memory active streams...");
const inMemoryStream = activeStreams.get(sessionId);
if (!inMemoryStream || inMemoryStream.status !== "streaming") {
console.info("[SSE-RECONNECT] No in-memory active stream found:", {
hasStream: !!inMemoryStream,
status: inMemoryStream?.status,
});
return;
}
const dispatcher = createStreamEventDispatcher({
setHasTextChunks,
@@ -160,6 +284,8 @@ export function useChatContainer({
onOperationStarted,
getActiveTask,
reconnectToTask,
activeStream,
setActiveTask,
],
);

View File

@@ -167,8 +167,15 @@ export async function executeTaskReconnect(
): Promise<void> {
const { abortController } = stream;
console.info("[SSE-RECONNECT] executeTaskReconnect starting:", {
taskId,
lastMessageId,
retryCount,
});
try {
const url = `/api/chat/tasks/${taskId}/stream?last_message_id=${encodeURIComponent(lastMessageId)}`;
console.info("[SSE-RECONNECT] Fetching task stream:", { url });
const response = await fetch(url, {
method: "GET",
@@ -178,15 +185,33 @@ export async function executeTaskReconnect(
signal: abortController.signal,
});
console.info("[SSE-RECONNECT] Task stream response:", {
status: response.status,
ok: response.ok,
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(errorText || `HTTP ${response.status}`);
console.error("[SSE-RECONNECT] Task stream error response:", {
status: response.status,
errorText,
});
// Don't retry on 404 (task not found) or 403 (access denied) - these are permanent errors
const isPermanentError =
response.status === 404 || response.status === 403;
const error = new Error(errorText || `HTTP ${response.status}`);
(error as Error & { status?: number }).status = response.status;
(error as Error & { isPermanent?: boolean }).isPermanent =
isPermanentError;
throw error;
}
if (!response.body) {
throw new Error("Response body is null");
}
console.info("[SSE-RECONNECT] Task stream connected, reading chunks...");
const reader = response.body.getReader();
const decoder = new TextDecoder();
let buffer = "";
@@ -195,6 +220,7 @@ export async function executeTaskReconnect(
const { done, value } = await reader.read();
if (done) {
console.info("[SSE-RECONNECT] Task stream reader done (connection closed)");
notifySubscribers(stream, { type: "stream_end" });
stream.status = "completed";
return;
@@ -208,6 +234,7 @@ export async function executeTaskReconnect(
const data = parseSSELine(line);
if (data !== null) {
if (data === "[DONE]") {
console.info("[SSE-RECONNECT] Task stream received [DONE] signal");
notifySubscribers(stream, { type: "stream_end" });
stream.status = "completed";
return;
@@ -220,14 +247,24 @@ export async function executeTaskReconnect(
const chunk = normalizeStreamChunk(rawChunk);
if (!chunk) continue;
// Log first few chunks for debugging
if (stream.chunks.length < 3) {
console.info("[SSE-RECONNECT] Task stream chunk received:", {
type: chunk.type,
chunkIndex: stream.chunks.length,
});
}
notifySubscribers(stream, chunk);
if (chunk.type === "stream_end") {
console.info("[SSE-RECONNECT] Task stream completed via stream_end chunk");
stream.status = "completed";
return;
}
if (chunk.type === "error") {
console.error("[SSE-RECONNECT] Task stream error chunk:", chunk);
stream.status = "error";
stream.error = new Error(
chunk.message || chunk.content || "Stream error",
@@ -250,17 +287,35 @@ export async function executeTaskReconnect(
return;
}
if (retryCount < MAX_RETRIES) {
// Check if this is a permanent error (404/403) that shouldn't be retried
const isPermanentError =
err instanceof Error &&
(err as Error & { isPermanent?: boolean }).isPermanent;
if (!isPermanentError && retryCount < MAX_RETRIES) {
const retryDelay = INITIAL_RETRY_DELAY * Math.pow(2, retryCount);
console.log(
`[StreamExecutor] Task reconnect retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`,
);
await new Promise((resolve) => setTimeout(resolve, retryDelay));
return executeTaskReconnect(stream, taskId, lastMessageId, retryCount + 1);
return executeTaskReconnect(
stream,
taskId,
lastMessageId,
retryCount + 1,
);
}
// Log permanent errors differently for debugging
if (isPermanentError) {
console.log(
`[StreamExecutor] Task reconnect failed permanently (task not found or access denied): ${(err as Error).message}`,
);
}
stream.status = "error";
stream.error = err instanceof Error ? err : new Error("Task reconnect failed");
stream.error =
err instanceof Error ? err : new Error("Task reconnect failed");
notifySubscribers(stream, {
type: "error",
message: stream.error.message,

View File

@@ -28,7 +28,8 @@ export function normalizeStreamChunk(
switch (chunk.type) {
case "text-delta":
return { type: "text_chunk", content: chunk.delta };
// Backend sends "content", Vercel AI SDK sends "delta"
return { type: "text_chunk", content: chunk.delta || chunk.content };
case "text-end":
return { type: "text_ended" };
case "tool-input-available":
@@ -63,6 +64,10 @@ export function normalizeStreamChunk(
case "finish":
return { type: "stream_end" };
case "start":
// Start event with optional taskId for reconnection
return chunk.taskId
? { type: "stream_start", taskId: chunk.taskId }
: null;
case "text-start":
return null;
case "tool-input-start":