Compare commits

...

15 Commits

Author SHA1 Message Date
Swifty
070d56166c simplify and ensure agents are added to store 2026-02-02 16:15:38 +01:00
Swifty
ef3fab57fd update to only use redis for integration 2026-02-02 15:33:36 +01:00
Swifty
e812ee9265 Merge branch 'dev' into swiftyos/sse-long-running-tasks 2026-02-02 14:12:15 +01:00
Swifty
eb3872d78b doc 2026-02-02 14:09:00 +01:00
Swifty
2cdfe90c56 fixing sse reconnection 2026-02-02 14:08:20 +01:00
Swifty
d1da7fe5da remove call for onbaording step 2026-01-30 12:15:35 +01:00
Swifty
11e27cfdcf Merge branch 'dev' into swiftyos/sse-long-running-tasks 2026-01-30 12:01:45 +01:00
Swifty
0be5fedc86 updating sse reconection logic be 2026-01-30 11:58:42 +01:00
Swifty
f2e81648b5 updating SSE reconnection logic 2026-01-30 11:58:25 +01:00
Swifty
bb608ea60d pr comments 2026-01-29 22:29:17 +01:00
Swifty
46af3b94f2 Merge branch 'swiftyos/sse-long-running-tasks' of github.com:Significant-Gravitas/AutoGPT into swiftyos/sse-long-running-tasks 2026-01-29 18:03:01 +01:00
Swifty
083cceca0f fixing edge cases 2026-01-29 18:02:21 +01:00
Swifty
06758adefd Merge branch 'dev' into swiftyos/sse-long-running-tasks 2026-01-29 13:33:32 +01:00
Swifty
c01c29a059 fmt issues 2026-01-29 13:28:01 +01:00
Swifty
d738059da8 added long running task support 2026-01-29 10:24:14 +01:00
32 changed files with 3375 additions and 312 deletions

View File

@@ -0,0 +1,303 @@
"""Redis Streams consumer for operation completion messages.
This module provides a consumer that listens for completion notifications
from external services (like Agent Generator) and triggers the appropriate
stream registry and chat service updates.
The consumer uses Redis Streams with consumer groups for reliable message
processing across multiple platform pods.
"""
import asyncio
import logging
import os
import uuid
import orjson
from prisma import Prisma
from pydantic import BaseModel
from redis.exceptions import ResponseError
from backend.data.redis_client import get_redis_async
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
logger = logging.getLogger(__name__)
config = ChatConfig()
class OperationCompleteMessage(BaseModel):
"""Message format for operation completion notifications."""
operation_id: str
task_id: str
success: bool
result: dict | str | None = None
error: str | None = None
class ChatCompletionConsumer:
"""Consumer for chat operation completion messages from Redis Streams.
This consumer initializes its own Prisma client in start() to ensure
database operations work correctly within this async context.
Uses Redis consumer groups to allow multiple platform pods to consume
messages reliably with automatic redelivery on failure.
"""
def __init__(self):
self._consumer_task: asyncio.Task | None = None
self._running = False
self._prisma: Prisma | None = None
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
async def start(self) -> None:
"""Start the completion consumer."""
if self._running:
logger.warning("Completion consumer already running")
return
# Create consumer group if it doesn't exist
try:
redis = await get_redis_async()
await redis.xgroup_create(
config.stream_completion_name,
config.stream_consumer_group,
id="0",
mkstream=True,
)
logger.info(
f"Created consumer group '{config.stream_consumer_group}' "
f"on stream '{config.stream_completion_name}'"
)
except ResponseError as e:
if "BUSYGROUP" in str(e):
logger.debug(
f"Consumer group '{config.stream_consumer_group}' already exists"
)
else:
raise
self._running = True
self._consumer_task = asyncio.create_task(self._consume_messages())
logger.info(
f"Chat completion consumer started (consumer: {self._consumer_name})"
)
async def _ensure_prisma(self) -> Prisma:
"""Lazily initialize Prisma client on first use."""
if self._prisma is None:
database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
self._prisma = Prisma(datasource={"url": database_url})
await self._prisma.connect()
logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
return self._prisma
async def stop(self) -> None:
"""Stop the completion consumer."""
self._running = False
if self._consumer_task:
self._consumer_task.cancel()
try:
await self._consumer_task
except asyncio.CancelledError:
pass
self._consumer_task = None
if self._prisma:
await self._prisma.disconnect()
self._prisma = None
logger.info("[COMPLETION] Consumer Prisma client disconnected")
logger.info("Chat completion consumer stopped")
async def _consume_messages(self) -> None:
"""Main message consumption loop with retry logic."""
max_retries = 10
retry_delay = 5 # seconds
retry_count = 0
block_timeout = 5000 # milliseconds
while self._running and retry_count < max_retries:
try:
redis = await get_redis_async()
# Reset retry count on successful connection
retry_count = 0
while self._running:
# Read new messages from the stream
messages = await redis.xreadgroup(
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
streams={config.stream_completion_name: ">"},
block=block_timeout,
count=10,
)
if not messages:
continue
for stream_name, entries in messages:
for entry_id, data in entries:
if not self._running:
return
try:
# Handle the message
message_data = data.get("data")
if message_data:
await self._handle_message(
message_data.encode()
if isinstance(message_data, str)
else message_data
)
# Acknowledge the message
await redis.xack(
config.stream_completion_name,
config.stream_consumer_group,
entry_id,
)
except Exception as e:
logger.error(
f"Error processing completion message {entry_id}: {e}",
exc_info=True,
)
# Message will be redelivered to another consumer
# or can be claimed after timeout
except asyncio.CancelledError:
logger.info("Consumer cancelled")
return
except Exception as e:
retry_count += 1
logger.error(
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
exc_info=True,
)
if self._running and retry_count < max_retries:
await asyncio.sleep(retry_delay)
else:
logger.error("Max retries reached, stopping consumer")
return
async def _handle_message(self, body: bytes) -> None:
"""Handle a completion message using our own Prisma client."""
try:
data = orjson.loads(body)
message = OperationCompleteMessage(**data)
except Exception as e:
logger.error(f"Failed to parse completion message: {e}")
return
logger.info(
f"[COMPLETION] Received completion for operation {message.operation_id} "
f"(task_id={message.task_id}, success={message.success})"
)
# Find task in registry
task = await stream_registry.find_task_by_operation_id(message.operation_id)
if task is None:
task = await stream_registry.get_task(message.task_id)
if task is None:
logger.warning(
f"[COMPLETION] Task not found for operation {message.operation_id} "
f"(task_id={message.task_id})"
)
return
logger.info(
f"[COMPLETION] Found task: task_id={task.task_id}, "
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
)
# Guard against empty task fields
if not task.task_id or not task.session_id or not task.tool_call_id:
logger.error(
f"[COMPLETION] Task has empty critical fields! "
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
f"tool_call_id={task.tool_call_id!r}"
)
return
if message.success:
await self._handle_success(task, message)
else:
await self._handle_failure(task, message)
async def _handle_success(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle successful operation completion."""
prisma = await self._ensure_prisma()
await process_operation_success(task, message.result, prisma)
async def _handle_failure(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle failed operation completion."""
prisma = await self._ensure_prisma()
await process_operation_failure(task, message.error, prisma)
# Module-level consumer instance
_consumer: ChatCompletionConsumer | None = None
async def start_completion_consumer() -> None:
"""Start the global completion consumer."""
global _consumer
if _consumer is None:
_consumer = ChatCompletionConsumer()
await _consumer.start()
async def stop_completion_consumer() -> None:
"""Stop the global completion consumer."""
global _consumer
if _consumer:
await _consumer.stop()
_consumer = None
async def publish_operation_complete(
operation_id: str,
task_id: str,
success: bool,
result: dict | str | None = None,
error: str | None = None,
) -> None:
"""Publish an operation completion message to Redis Streams.
Args:
operation_id: The operation ID that completed.
task_id: The task ID associated with the operation.
success: Whether the operation succeeded.
result: The result data (for success).
error: The error message (for failure).
"""
message = OperationCompleteMessage(
operation_id=operation_id,
task_id=task_id,
success=success,
result=result,
error=error,
)
redis = await get_redis_async()
await redis.xadd(
config.stream_completion_name,
{"data": message.model_dump_json()},
maxlen=config.stream_max_length,
)
logger.info(f"Published completion for operation {operation_id}")

View File

@@ -0,0 +1,255 @@
"""Shared completion handling for operation success and failure.
This module provides common logic for handling operation completion from both:
- The Redis Streams consumer (completion_consumer.py)
- The HTTP webhook endpoint (routes.py)
"""
import logging
from typing import Any
import orjson
from prisma import Prisma
from . import service as chat_service
from . import stream_registry
from .response_model import StreamError, StreamFinish, StreamToolOutputAvailable
from .tools.models import ErrorResponse
logger = logging.getLogger(__name__)
# Tools that produce agent_json that needs to be saved to library
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
def serialize_result(result: dict | str | None) -> str:
"""Serialize result to JSON string with sensible defaults.
Args:
result: The result to serialize (dict, string, or None)
Returns:
JSON string representation of the result
"""
if isinstance(result, str):
return result
if result:
return orjson.dumps(result).decode("utf-8")
return '{"status": "completed"}'
async def _save_agent_from_result(
result: dict[str, Any],
user_id: str | None,
tool_name: str,
) -> dict[str, Any]:
"""Save agent to library if result contains agent_json.
Args:
result: The result dict that may contain agent_json
user_id: The user ID to save the agent for
tool_name: The tool name (create_agent or edit_agent)
Returns:
Updated result dict with saved agent details, or original result if no agent_json
"""
if not user_id:
logger.warning(
"[COMPLETION] Cannot save agent: no user_id in task"
)
return result
agent_json = result.get("agent_json")
if not agent_json:
logger.warning(
f"[COMPLETION] {tool_name} completed but no agent_json in result"
)
return result
try:
from .tools.agent_generator import save_agent_to_library
is_update = tool_name == "edit_agent"
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id, is_update=is_update
)
logger.info(
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
)
# Return a response similar to AgentSavedResponse
return {
"type": "agent_saved",
"message": f"Agent '{created_graph.name}' has been saved to your library!",
"agent_id": created_graph.id,
"agent_name": created_graph.name,
"library_agent_id": library_agent.id,
"library_agent_link": f"/library/agents/{library_agent.id}",
"agent_page_link": f"/build?flowID={created_graph.id}",
}
except Exception as e:
logger.error(
f"[COMPLETION] Failed to save agent to library: {e}",
exc_info=True,
)
# Return error but don't fail the whole operation
return {
"type": "error",
"message": f"Agent was generated but failed to save: {str(e)}",
"error": str(e),
"agent_json": agent_json, # Include the JSON so user can retry
}
async def process_operation_success(
task: stream_registry.ActiveTask,
result: dict | str | None,
prisma_client: Prisma | None = None,
) -> None:
"""Handle successful operation completion.
Publishes the result to the stream registry, updates the database,
generates LLM continuation, and marks the task as completed.
Args:
task: The active task that completed
result: The result data from the operation
prisma_client: Optional Prisma client for database operations.
If None, uses chat_service._update_pending_operation instead.
"""
# For agent generation tools, save the agent to library
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
# Serialize result for output
result_output = result if result else {"status": "completed"}
output_str = (
result_output
if isinstance(result_output, str)
else orjson.dumps(result_output).decode("utf-8")
)
# Publish result to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
output=output_str,
success=True,
),
)
# Update pending operation in database
result_str = serialize_result(result)
try:
if prisma_client:
# Use provided Prisma client (for consumer with its own connection)
await prisma_client.chatmessage.update_many(
where={
"sessionId": task.session_id,
"toolCallId": task.tool_call_id,
},
data={"content": result_str},
)
logger.info(
f"[COMPLETION] Updated tool message for session {task.session_id}"
)
else:
# Use service function (for webhook endpoint)
await chat_service._update_pending_operation(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
result=result_str,
)
except Exception as e:
logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
# Generate LLM continuation with streaming
try:
await chat_service._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
except Exception as e:
logger.error(
f"[COMPLETION] Failed to generate LLM continuation: {e}",
exc_info=True,
)
# Mark task as completed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="completed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
)
async def process_operation_failure(
task: stream_registry.ActiveTask,
error: str | None,
prisma_client: Prisma | None = None,
) -> None:
"""Handle failed operation completion.
Publishes the error to the stream registry, updates the database with
the error response, and marks the task as failed.
Args:
task: The active task that failed
error: The error message from the operation
prisma_client: Optional Prisma client for database operations.
If None, uses chat_service._update_pending_operation instead.
"""
error_msg = error or "Operation failed"
# Publish error to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText=error_msg),
)
await stream_registry.publish_chunk(task.task_id, StreamFinish())
# Update pending operation with error
error_response = ErrorResponse(
message=error_msg,
error=error,
)
try:
if prisma_client:
# Use provided Prisma client (for consumer with its own connection)
await prisma_client.chatmessage.update_many(
where={
"sessionId": task.session_id,
"toolCallId": task.tool_call_id,
},
data={"content": error_response.model_dump_json()},
)
logger.info(
f"[COMPLETION] Updated tool message with error for session {task.session_id}"
)
else:
# Use service function (for webhook endpoint)
await chat_service._update_pending_operation(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
result=error_response.model_dump_json(),
)
except Exception as e:
logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
# Mark task as failed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="failed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")

View File

@@ -44,6 +44,48 @@ class ChatConfig(BaseSettings):
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
)
# Stream registry configuration for SSE reconnection
stream_ttl: int = Field(
default=3600,
description="TTL in seconds for stream data in Redis (1 hour)",
)
stream_max_length: int = Field(
default=10000,
description="Maximum number of messages to store per stream",
)
# Redis Streams configuration for completion consumer
stream_completion_name: str = Field(
default="chat:completions",
description="Redis Stream name for operation completions",
)
stream_consumer_group: str = Field(
default="chat_consumers",
description="Consumer group name for completion stream",
)
# Redis key prefixes for stream registry
task_meta_prefix: str = Field(
default="chat:task:meta:",
description="Prefix for task metadata hash keys",
)
task_stream_prefix: str = Field(
default="chat:stream:",
description="Prefix for task message stream keys",
)
task_op_prefix: str = Field(
default="chat:task:op:",
description="Prefix for operation ID to task ID mapping keys",
)
task_pubsub_prefix: str = Field(
default="chat:task:pubsub:",
description="Prefix for task pub/sub channel names",
)
internal_api_key: str | None = Field(
default=None,
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
)
# Langfuse Prompt Management Configuration
# Note: Langfuse credentials are in Settings().secrets (settings.py)
langfuse_prompt_name: str = Field(
@@ -82,6 +124,14 @@ class ChatConfig(BaseSettings):
v = "https://openrouter.ai/api/v1"
return v
@field_validator("internal_api_key", mode="before")
@classmethod
def get_internal_api_key(cls, v):
"""Get internal API key from environment if not provided."""
if v is None:
v = os.getenv("CHAT_INTERNAL_API_KEY")
return v
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",

View File

@@ -52,6 +52,10 @@ class StreamStart(StreamBaseResponse):
type: ResponseType = ResponseType.START
messageId: str = Field(..., description="Unique message ID")
taskId: str | None = Field(
default=None,
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
)
class StreamFinish(StreamBaseResponse):

View File

@@ -1,19 +1,23 @@
"""Chat API routes for chat session management and streaming via SSE."""
import logging
import uuid as uuid_module
from collections.abc import AsyncGenerator
from typing import Annotated
from autogpt_libs import auth
from fastapi import APIRouter, Depends, Query, Security
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from backend.util.exceptions import NotFoundError
from . import service as chat_service
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
config = ChatConfig()
@@ -55,6 +59,13 @@ class CreateSessionResponse(BaseModel):
user_id: str | None
class ActiveStreamInfo(BaseModel):
"""Information about an active stream for reconnection."""
task_id: str
last_message_id: str # Redis Stream message ID for resumption
class SessionDetailResponse(BaseModel):
"""Response model providing complete details for a chat session, including messages."""
@@ -63,6 +74,7 @@ class SessionDetailResponse(BaseModel):
updated_at: str
user_id: str | None
messages: list[dict]
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
class SessionSummaryResponse(BaseModel):
@@ -81,6 +93,14 @@ class ListSessionsResponse(BaseModel):
total: int
class OperationCompleteRequest(BaseModel):
"""Request model for external completion webhook."""
success: bool
result: dict | str | None = None
error: str | None = None
# ========== Routes ==========
@@ -166,13 +186,14 @@ async def get_session(
Retrieve the details of a specific chat session.
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
If there's an active stream for this session, returns the task_id for reconnection.
Args:
session_id: The unique identifier for the desired chat session.
user_id: The optional authenticated user ID, or None for anonymous access.
Returns:
SessionDetailResponse: Details for the requested session, or None if not found.
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
"""
session = await get_chat_session(session_id, user_id)
@@ -180,10 +201,45 @@ async def get_session(
raise NotFoundError(f"Session {session_id} not found.")
messages = [message.model_dump() for message in session.messages]
# Check if there's an active stream for this session
active_stream_info = None
logger.info(f"[SSE-RECONNECT] Checking for active stream in session {session_id}")
active_task, last_message_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
if active_task:
# Filter out the in-progress assistant message from the session response.
# The client will receive the complete assistant response through the SSE
# stream replay instead, preventing duplicate content.
if messages and messages[-1].get("role") == "assistant":
original_count = len(messages)
messages = messages[:-1]
logger.info(
f"[SSE-RECONNECT] Filtered out in-progress assistant message "
f"(was {original_count} messages, now {len(messages)})"
)
# Use "0-0" as last_message_id to replay the stream from the beginning.
# Since we filtered out the cached assistant message, the client needs
# the full stream to reconstruct the response.
active_stream_info = ActiveStreamInfo(
task_id=active_task.task_id,
last_message_id="0-0",
)
logger.info(
f"[SSE-RECONNECT] Session {session_id} HAS active stream: "
f"task_id={active_task.task_id}, status={active_task.status}, "
f"last_message_id=0-0 (replay from start)"
)
else:
logger.info(f"[SSE-RECONNECT] Session {session_id} has NO active stream")
logger.info(
f"Returning session {session_id}: "
f"message_count={len(messages)}, "
f"roles={[m.get('role') for m in messages]}"
f"roles={[m.get('role') for m in messages]}, "
f"has_active_stream={active_stream_info is not None}"
)
return SessionDetailResponse(
@@ -192,6 +248,7 @@ async def get_session(
updated_at=session.updated_at.isoformat(),
user_id=session.user_id or None,
messages=messages,
active_stream=active_stream_info,
)
@@ -211,49 +268,136 @@ async def stream_chat_post(
- Tool call UI elements (if invoked)
- Tool execution results
The AI generation runs in a background task that continues even if the client disconnects.
All chunks are written to Redis for reconnection support. If the client disconnects,
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
Args:
session_id: The chat session identifier to associate with the streamed messages.
request: Request body containing message, is_user_message, and optional context.
user_id: Optional authenticated user ID.
Returns:
StreamingResponse: SSE-formatted response chunks.
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
containing the task_id for reconnection.
"""
import asyncio
session = await _validate_and_get_session(session_id, user_id)
async def event_generator() -> AsyncGenerator[str, None]:
# Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4())
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id="chat_stream", # Not a tool call, but needed for the model
tool_name="chat",
operation_id=operation_id,
)
logger.info(
f"[SSE-RECONNECT] Created stream task for reconnection support: "
f"task_id={task_id}, session_id={session_id}"
)
# Background task that runs the AI generation independently of SSE connection
async def run_ai_generation():
chunk_count = 0
first_chunk_type: str | None = None
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
# AI SDK protocol termination
yield "data: [DONE]\n\n"
try:
# Emit a start event with task_id for reconnection
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
await stream_registry.publish_chunk(task_id, start_chunk)
async for chunk in chat_service.stream_chat_completion(
session_id,
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
# Write to Redis (subscribers will receive via pub/sub or polling)
await stream_registry.publish_chunk(task_id, chunk)
# Mark task as completed
await stream_registry.mark_task_completed(task_id, "completed")
logger.info(
"[SSE-RECONNECT] Background AI generation completed",
extra={
"session_id": session_id,
"task_id": task_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
except Exception as e:
logger.error(
f"[SSE-RECONNECT] Error in background AI generation for session "
f"{session_id}: {e}"
)
await stream_registry.mark_task_completed(task_id, "failed")
# Start the AI generation in a background task
bg_task = asyncio.create_task(run_ai_generation())
await stream_registry.set_task_asyncio_task(task_id, bg_task)
logger.info(f"[SSE-RECONNECT] Started background AI generation task for {task_id}")
# SSE endpoint that subscribes to the task's stream
async def event_generator() -> AsyncGenerator[str, None]:
try:
# Subscribe to the task stream (this replays existing messages + live updates)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id="0-0", # Get all messages from the beginning
)
if subscriber_queue is None:
logger.error(f"Failed to subscribe to task {task_id}")
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return
# Read from the subscriber queue and yield to SSE
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
logger.info(
f"[SSE-RECONNECT] SSE subscriber received finish for task {task_id}"
)
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except GeneratorExit:
# Client disconnected - that's fine, background task continues
logger.info(
f"[SSE-RECONNECT] SSE client disconnected for task {task_id}, "
f"background generation continues"
)
except Exception as e:
logger.error(f"Error in SSE stream for task {task_id}: {e}")
finally:
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
@@ -366,6 +510,207 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Task Streaming (SSE Reconnection) ==========
@router.get(
"/tasks/{task_id}/stream",
)
async def stream_task(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
last_message_id: str = Query(
default="0-0",
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
),
):
"""
Reconnect to a long-running task's SSE stream.
When a long-running operation (like agent generation) starts, the client
receives a task_id. If the connection drops, the client can reconnect
using this endpoint to resume receiving updates.
Args:
task_id: The task ID from the operation_started response.
user_id: Authenticated user ID for ownership validation.
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
Returns:
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
Raises:
NotFoundError: If task_id is not found or user doesn't have access.
"""
logger.info(
f"[SSE-RECONNECT] Client reconnecting to task stream: "
f"task_id={task_id}, last_message_id={last_message_id}"
)
# Get subscriber queue from stream registry
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id=last_message_id,
)
if subscriber_queue is None:
logger.warning(
f"[SSE-RECONNECT] Task not found or access denied: task_id={task_id}"
)
raise NotFoundError(f"Task {task_id} not found or access denied.")
logger.info(
f"[SSE-RECONNECT] Successfully subscribed to task stream: task_id={task_id}"
)
async def event_generator() -> AsyncGenerator[str, None]:
import asyncio
chunk_count = 0
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try:
while True:
try:
# Wait for next chunk with timeout for heartbeats
chunk = await asyncio.wait_for(
subscriber_queue.get(), timeout=heartbeat_interval
)
chunk_count += 1
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
logger.info(
f"Task stream completed for task {task_id}, "
f"chunk_count={chunk_count}"
)
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except Exception as e:
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
finally:
# Unsubscribe when client disconnects or stream ends
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
},
)
@router.get(
"/tasks/{task_id}",
)
async def get_task_status(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
) -> dict:
"""
Get the status of a long-running task.
Args:
task_id: The task ID to check.
user_id: Authenticated user ID for ownership validation.
Returns:
dict: Task status including task_id, status, tool_name, and operation_id.
Raises:
NotFoundError: If task_id is not found or user doesn't have access.
"""
task = await stream_registry.get_task(task_id)
if task is None:
raise NotFoundError(f"Task {task_id} not found.")
# Validate ownership
if user_id and task.user_id and task.user_id != user_id:
raise NotFoundError(f"Task {task_id} not found.")
return {
"task_id": task.task_id,
"session_id": task.session_id,
"status": task.status,
"tool_name": task.tool_name,
"operation_id": task.operation_id,
"created_at": task.created_at.isoformat(),
}
# ========== External Completion Webhook ==========
@router.post(
"/operations/{operation_id}/complete",
status_code=200,
)
async def complete_operation(
operation_id: str,
request: OperationCompleteRequest,
x_api_key: str | None = Header(default=None),
) -> dict:
"""
External completion webhook for long-running operations.
Called by Agent Generator (or other services) when an operation completes.
This triggers the stream registry to publish completion and continue LLM generation.
Args:
operation_id: The operation ID to complete.
request: Completion payload with success status and result/error.
x_api_key: Internal API key for authentication.
Returns:
dict: Status of the completion.
Raises:
HTTPException: If API key is invalid or operation not found.
"""
# Validate internal API key - reject if not configured or invalid
if not config.internal_api_key:
logger.error(
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
)
raise HTTPException(
status_code=503,
detail="Webhook not available: internal API key not configured",
)
if x_api_key != config.internal_api_key:
raise HTTPException(status_code=401, detail="Invalid API key")
# Find task by operation_id
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None:
raise HTTPException(
status_code=404,
detail=f"Operation {operation_id} not found",
)
logger.info(
f"Received completion webhook for operation {operation_id} "
f"(task_id={task.task_id}, success={request.success})"
)
if request.success:
await process_operation_success(task, request.result)
else:
await process_operation_failure(task, request.error)
return {"status": "ok", "task_id": task.task_id}
# ========== Health Check ==========

View File

@@ -26,6 +26,7 @@ from backend.util.exceptions import NotFoundError
from backend.util.settings import Settings
from . import db as chat_db
from . import stream_registry
from .config import ChatConfig
from .model import (
ChatMessage,
@@ -1610,8 +1611,9 @@ async def _yield_tool_call(
)
return
# Generate operation ID
# Generate operation ID and task ID
operation_id = str(uuid_module.uuid4())
task_id = str(uuid_module.uuid4())
# Build a user-friendly message based on tool and arguments
if tool_name == "create_agent":
@@ -1654,6 +1656,16 @@ async def _yield_tool_call(
# Wrap session save and task creation in try-except to release lock on failure
try:
# Create task in stream registry for SSE reconnection support
await stream_registry.create_task(
task_id=task_id,
session_id=session.session_id,
user_id=session.user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
operation_id=operation_id,
)
# Save assistant message with tool_call FIRST (required by LLM)
assistant_message = ChatMessage(
role="assistant",
@@ -1675,23 +1687,27 @@ async def _yield_tool_call(
session.messages.append(pending_message)
await upsert_chat_session(session)
logger.info(
f"Saved pending operation {operation_id} for tool {tool_name} "
f"in session {session.session_id}"
f"Saved pending operation {operation_id} (task_id={task_id}) "
f"for tool {tool_name} in session {session.session_id}"
)
# Store task reference in module-level set to prevent GC before completion
task = asyncio.create_task(
_execute_long_running_tool(
bg_task = asyncio.create_task(
_execute_long_running_tool_with_streaming(
tool_name=tool_name,
parameters=arguments,
tool_call_id=tool_call_id,
operation_id=operation_id,
task_id=task_id,
session_id=session.session_id,
user_id=session.user_id,
)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
_background_tasks.add(bg_task)
bg_task.add_done_callback(_background_tasks.discard)
# Associate the asyncio task with the stream registry task
await stream_registry.set_task_asyncio_task(task_id, bg_task)
except Exception as e:
# Roll back appended messages to prevent data corruption on subsequent saves
if (
@@ -1709,6 +1725,11 @@ async def _yield_tool_call(
# Release the Redis lock since the background task won't be spawned
await _mark_operation_completed(tool_call_id)
# Mark stream registry task as failed if it was created
try:
await stream_registry.mark_task_completed(task_id, status="failed")
except Exception:
pass
logger.error(
f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
)
@@ -1722,6 +1743,7 @@ async def _yield_tool_call(
message=started_msg,
operation_id=operation_id,
tool_name=tool_name,
task_id=task_id, # Include task_id for SSE reconnection
).model_dump_json(),
success=True,
)
@@ -1791,6 +1813,9 @@ async def _execute_long_running_tool(
This function runs independently of the SSE connection, so the operation
survives if the user closes their browser tab.
NOTE: This is the legacy function without stream registry support.
Use _execute_long_running_tool_with_streaming for new implementations.
"""
try:
# Load fresh session (not stale reference)
@@ -1843,6 +1868,128 @@ async def _execute_long_running_tool(
await _mark_operation_completed(tool_call_id)
async def _execute_long_running_tool_with_streaming(
tool_name: str,
parameters: dict[str, Any],
tool_call_id: str,
operation_id: str,
task_id: str,
session_id: str,
user_id: str | None,
) -> None:
"""Execute a long-running tool with stream registry support for SSE reconnection.
This function runs independently of the SSE connection, publishes progress
to the stream registry, and survives if the user closes their browser tab.
Clients can reconnect via GET /chat/tasks/{task_id}/stream to resume streaming.
If the external service returns a 202 Accepted (async), this function exits
early and lets the RabbitMQ completion consumer handle the rest.
"""
# Track whether we delegated to async processing - if so, the RabbitMQ
# completion consumer will handle cleanup, not us
delegated_to_async = False
try:
# Load fresh session (not stale reference)
session = await get_chat_session(session_id, user_id)
if not session:
logger.error(f"Session {session_id} not found for background tool")
await stream_registry.mark_task_completed(task_id, status="failed")
return
# Pass operation_id and task_id to the tool for async processing
enriched_parameters = {
**parameters,
"_operation_id": operation_id,
"_task_id": task_id,
}
# Execute the actual tool
result = await execute_tool(
tool_name=tool_name,
parameters=enriched_parameters,
tool_call_id=tool_call_id,
user_id=user_id,
session=session,
)
# Check if the tool result indicates async processing
# (e.g., Agent Generator returned 202 Accepted)
try:
result_data = orjson.loads(result.output) if result.output else {}
if result_data.get("status") == "accepted":
logger.info(
f"Tool {tool_name} delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id}). "
f"RabbitMQ completion consumer will handle the rest."
)
# Don't publish result, don't continue with LLM, and don't cleanup
# The RabbitMQ consumer will handle everything when the external
# service completes and publishes to the queue
delegated_to_async = True
return
except (orjson.JSONDecodeError, TypeError):
pass # Not JSON or not async - continue normally
# Publish tool result to stream registry
await stream_registry.publish_chunk(task_id, result)
# Update the pending message with result
result_str = (
result.output
if isinstance(result.output, str)
else orjson.dumps(result.output).decode("utf-8")
)
await _update_pending_operation(
session_id=session_id,
tool_call_id=tool_call_id,
result=result_str,
)
logger.info(
f"Background tool {tool_name} completed for session {session_id} "
f"(task_id={task_id})"
)
# Generate LLM continuation and stream chunks to registry
await _generate_llm_continuation_with_streaming(
session_id=session_id,
user_id=user_id,
task_id=task_id,
)
# Mark task as completed in stream registry
await stream_registry.mark_task_completed(task_id, status="completed")
except Exception as e:
logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True)
error_response = ErrorResponse(
message=f"Tool {tool_name} failed: {str(e)}",
)
# Publish error to stream registry followed by finish event
await stream_registry.publish_chunk(
task_id,
StreamError(errorText=str(e)),
)
await stream_registry.publish_chunk(task_id, StreamFinish())
await _update_pending_operation(
session_id=session_id,
tool_call_id=tool_call_id,
result=error_response.model_dump_json(),
)
# Mark task as failed in stream registry
await stream_registry.mark_task_completed(task_id, status="failed")
finally:
# Only cleanup if we didn't delegate to async processing
# For async path, the RabbitMQ completion consumer handles cleanup
if not delegated_to_async:
await _mark_operation_completed(tool_call_id)
async def _update_pending_operation(
session_id: str,
tool_call_id: str,
@@ -1969,3 +2116,128 @@ async def _generate_llm_continuation(
except Exception as e:
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
async def _generate_llm_continuation_with_streaming(
session_id: str,
user_id: str | None,
task_id: str,
) -> None:
"""Generate an LLM response with streaming to the stream registry.
This is called by background tasks to continue the conversation
after a tool result is saved. Chunks are published to the stream registry
so reconnecting clients can receive them.
"""
import uuid as uuid_module
try:
# Load fresh session from DB (bypass cache to get the updated tool result)
await invalidate_session_cache(session_id)
session = await get_chat_session(session_id, user_id)
if not session:
logger.error(f"Session {session_id} not found for LLM continuation")
return
# Build system prompt
system_prompt, _ = await _build_system_prompt(user_id)
# Build messages in OpenAI format
messages = session.to_openai_messages()
if system_prompt:
from openai.types.chat import ChatCompletionSystemMessageParam
system_message = ChatCompletionSystemMessageParam(
role="system",
content=system_prompt,
)
messages = [system_message] + messages
# Build extra_body for tracing
extra_body: dict[str, Any] = {
"posthogProperties": {
"environment": settings.config.app_env.value,
},
}
if user_id:
extra_body["user"] = user_id[:128]
extra_body["posthogDistinctId"] = user_id
if session_id:
extra_body["session_id"] = session_id[:128]
# Make streaming LLM call (no tools - just text response)
from typing import cast
from openai.types.chat import ChatCompletionMessageParam
# Generate unique IDs for AI SDK protocol
message_id = str(uuid_module.uuid4())
text_block_id = str(uuid_module.uuid4())
# Publish start event
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
# Stream the response
stream = await client.chat.completions.create(
model=config.model,
messages=cast(list[ChatCompletionMessageParam], messages),
extra_body=extra_body,
stream=True,
)
assistant_content = ""
async for chunk in stream:
if chunk.choices and chunk.choices[0].delta.content:
delta = chunk.choices[0].delta.content
assistant_content += delta
# Publish delta to stream registry
await stream_registry.publish_chunk(
task_id,
StreamTextDelta(id=text_block_id, delta=delta),
)
# Publish end events
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
if assistant_content:
# Reload session from DB to avoid race condition with user messages
fresh_session = await get_chat_session(session_id, user_id)
if not fresh_session:
logger.error(
f"Session {session_id} disappeared during LLM continuation"
)
return
# Save assistant message to database
assistant_message = ChatMessage(
role="assistant",
content=assistant_content,
)
fresh_session.messages.append(assistant_message)
# Save to database (not cache) to persist the response
await upsert_chat_session(fresh_session)
# Invalidate cache so next poll/refresh gets fresh data
await invalidate_session_cache(session_id)
logger.info(
f"Generated streaming LLM continuation for session {session_id} "
f"(task_id={task_id}), response length: {len(assistant_content)}"
)
else:
logger.warning(
f"Streaming LLM continuation returned empty response for {session_id}"
)
except Exception as e:
logger.error(
f"Failed to generate streaming LLM continuation: {e}", exc_info=True
)
# Publish error to stream registry followed by finish event
await stream_registry.publish_chunk(
task_id,
StreamError(errorText=f"Failed to generate response: {e}"),
)
await stream_registry.publish_chunk(task_id, StreamFinish())

View File

@@ -0,0 +1,614 @@
"""Stream registry for managing reconnectable SSE streams.
This module provides a registry for tracking active streaming tasks and their
messages. It uses Redis for all state management (no in-memory state), making
pods stateless and horizontally scalable.
Architecture:
- Redis Stream: Persists all messages for replay
- Redis Pub/Sub: Real-time delivery to subscribers
- Redis Hash: Task metadata (status, session_id, etc.)
Subscribers:
1. Replay missed messages from Redis Stream
2. Subscribe to pub/sub channel for live updates
3. No in-memory state required on the subscribing pod
"""
import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal
import orjson
from backend.data.redis_client import get_redis_async
from .config import ChatConfig
from .response_model import StreamBaseResponse, StreamFinish
logger = logging.getLogger(__name__)
config = ChatConfig()
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
_local_tasks: dict[str, asyncio.Task] = {}
@dataclass
class ActiveTask:
"""Represents an active streaming task (metadata only, no in-memory queues)."""
task_id: str
session_id: str
user_id: str | None
tool_call_id: str
tool_name: str
operation_id: str
status: Literal["running", "completed", "failed"] = "running"
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
asyncio_task: asyncio.Task | None = None
def _get_task_meta_key(task_id: str) -> str:
"""Get Redis key for task metadata."""
return f"{config.task_meta_prefix}{task_id}"
def _get_task_stream_key(task_id: str) -> str:
"""Get Redis key for task message stream."""
return f"{config.task_stream_prefix}{task_id}"
def _get_operation_mapping_key(operation_id: str) -> str:
"""Get Redis key for operation_id to task_id mapping."""
return f"{config.task_op_prefix}{operation_id}"
def _get_task_pubsub_channel(task_id: str) -> str:
"""Get Redis pub/sub channel for task real-time delivery."""
return f"{config.task_pubsub_prefix}{task_id}"
async def create_task(
task_id: str,
session_id: str,
user_id: str | None,
tool_call_id: str,
tool_name: str,
operation_id: str,
) -> ActiveTask:
"""Create a new streaming task in Redis.
Args:
task_id: Unique identifier for the task
session_id: Chat session ID
user_id: User ID (may be None for anonymous)
tool_call_id: Tool call ID from the LLM
tool_name: Name of the tool being executed
operation_id: Operation ID for webhook callbacks
Returns:
The created ActiveTask instance (metadata only)
"""
task = ActiveTask(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
operation_id=operation_id,
)
# Store metadata in Redis
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
op_key = _get_operation_mapping_key(operation_id)
await redis.hset( # type: ignore[misc]
meta_key,
mapping={
"task_id": task_id,
"session_id": session_id,
"user_id": user_id or "",
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"operation_id": operation_id,
"status": task.status,
"created_at": task.created_at.isoformat(),
},
)
await redis.expire(meta_key, config.stream_ttl)
# Create operation_id -> task_id mapping for webhook lookups
await redis.set(op_key, task_id, ex=config.stream_ttl)
logger.info(
f"[SSE-RECONNECT] Created task {task_id} for session {session_id} in Redis"
)
return task
async def publish_chunk(
task_id: str,
chunk: StreamBaseResponse,
) -> str:
"""Publish a chunk to Redis Stream and pub/sub channel.
All delivery is via Redis - no in-memory state.
Args:
task_id: Task ID to publish to
chunk: The stream response chunk to publish
Returns:
The Redis Stream message ID
"""
chunk_json = chunk.model_dump_json()
message_id = "0-0"
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
pubsub_channel = _get_task_pubsub_channel(task_id)
# Write to Redis Stream for persistence/replay
raw_id = await redis.xadd(
stream_key,
{"data": chunk_json},
maxlen=config.stream_max_length,
)
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
# Publish to pub/sub for real-time delivery
await redis.publish(pubsub_channel, chunk_json)
logger.debug(f"Published chunk to task {task_id}, message_id={message_id}")
except Exception as e:
logger.error(
f"Failed to publish chunk for task {task_id}: {e}",
exc_info=True,
)
return message_id
async def subscribe_to_task(
task_id: str,
user_id: str | None,
last_message_id: str = "0-0",
) -> asyncio.Queue[StreamBaseResponse] | None:
"""Subscribe to a task's stream with replay of missed messages.
This is fully stateless - uses Redis Stream for replay and pub/sub for live updates.
Args:
task_id: Task ID to subscribe to
user_id: User ID for ownership validation
last_message_id: Last Redis Stream message ID received ("0-0" for full replay)
Returns:
An asyncio Queue that will receive stream chunks, or None if task not found
or user doesn't have access
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
logger.warning(f"[SSE-RECONNECT] Task {task_id} not found in Redis")
return None
# Note: Redis client uses decode_responses=True, so keys are strings
task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None
logger.info(f"[SSE-RECONNECT] Subscribing to task {task_id}: status={task_status}")
# Validate ownership
if user_id and task_user_id and task_user_id != user_id:
logger.warning(
f"User {user_id} attempted to subscribe to task {task_id} "
f"owned by {task_user_id}"
)
return None
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
stream_key = _get_task_stream_key(task_id)
# Step 1: Replay messages from Redis Stream
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
replayed_count = 0
replay_last_id = last_message_id
if messages:
for _stream_name, stream_messages in messages:
for msg_id, msg_data in stream_messages:
replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
# Note: Redis client uses decode_responses=True, so keys are strings
if "data" in msg_data:
try:
chunk_data = orjson.loads(msg_data["data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
await subscriber_queue.put(chunk)
replayed_count += 1
except Exception as e:
logger.warning(f"Failed to replay message: {e}")
logger.info(
f"[SSE-RECONNECT] Task {task_id}: replayed {replayed_count} messages "
f"(last_id={replay_last_id})"
)
# Step 2: If task is still running, start stream listener for live updates
if task_status == "running":
logger.info(
f"[SSE-RECONNECT] Task {task_id} is running, starting stream listener"
)
asyncio.create_task(_stream_listener(task_id, subscriber_queue, replay_last_id))
else:
# Task is completed/failed - add finish marker
logger.info(
f"[SSE-RECONNECT] Task {task_id} is {task_status}, adding finish marker"
)
await subscriber_queue.put(StreamFinish())
return subscriber_queue
async def _stream_listener(
task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
last_replayed_id: str,
) -> None:
"""Listen to Redis Stream for new messages using blocking XREAD.
This approach avoids the duplicate message issue that can occur with pub/sub
when messages are published during the gap between replay and subscription.
Args:
task_id: Task ID to listen for
subscriber_queue: Queue to deliver messages to
last_replayed_id: Last message ID from replay (continue from here)
"""
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
current_id = last_replayed_id
logger.debug(
f"[SSE-RECONNECT] Stream listener started for task {task_id}, "
f"from ID {current_id}"
)
while True:
# Block for up to 30 seconds waiting for new messages
# This allows periodic checking if task is still running
messages = await redis.xread(
{stream_key: current_id}, block=30000, count=100
)
if not messages:
# Timeout - check if task is still running
meta_key = _get_task_meta_key(task_id)
status = await redis.hget(meta_key, "status") # type: ignore[misc]
if status and status != "running":
logger.info(
f"[SSE-RECONNECT] Task {task_id} no longer running "
f"(status={status}), stopping listener"
)
subscriber_queue.put_nowait(StreamFinish())
break
continue
for _stream_name, stream_messages in messages:
for msg_id, msg_data in stream_messages:
current_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
if "data" not in msg_data:
continue
try:
chunk_data = orjson.loads(msg_data["data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
try:
subscriber_queue.put_nowait(chunk)
except asyncio.QueueFull:
logger.warning(
f"Subscriber queue full for task {task_id}"
)
# Stop listening on finish
if isinstance(chunk, StreamFinish):
logger.info(
f"[SSE-RECONNECT] Task {task_id} finished "
"via stream"
)
return
except Exception as e:
logger.warning(f"Error processing stream message: {e}")
except asyncio.CancelledError:
logger.debug(f"[SSE-RECONNECT] Stream listener cancelled for task {task_id}")
except Exception as e:
logger.error(f"Stream listener error for task {task_id}: {e}")
# On error, send finish to unblock subscriber
try:
subscriber_queue.put_nowait(StreamFinish())
except asyncio.QueueFull:
pass
async def mark_task_completed(
task_id: str,
status: Literal["completed", "failed"] = "completed",
) -> None:
"""Mark a task as completed and publish finish event.
Args:
task_id: Task ID to mark as completed
status: Final status ("completed" or "failed")
"""
# Publish finish event (goes to Redis Stream + pub/sub)
await publish_chunk(task_id, StreamFinish())
# Update Redis metadata
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
await redis.hset(meta_key, "status", status) # type: ignore[misc]
# Clean up local task reference if exists
_local_tasks.pop(task_id, None)
logger.info(f"[SSE-RECONNECT] Marked task {task_id} as {status}")
async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
"""Find a task by its operation ID.
Used by webhook callbacks to locate the task to update.
Args:
operation_id: Operation ID to search for
Returns:
ActiveTask if found, None otherwise
"""
redis = await get_redis_async()
op_key = _get_operation_mapping_key(operation_id)
task_id = await redis.get(op_key)
logger.info(
f"[SSE-RECONNECT] find_task_by_operation_id: "
f"op_key={op_key}, task_id_from_redis={task_id!r}"
)
if not task_id:
logger.info(f"[SSE-RECONNECT] No task_id found for operation {operation_id}")
return None
task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
logger.info(f"[SSE-RECONNECT] Looking up task by task_id={task_id_str}")
return await get_task(task_id_str)
async def get_task(task_id: str) -> ActiveTask | None:
"""Get a task by its ID from Redis.
Args:
task_id: Task ID to look up
Returns:
ActiveTask if found, None otherwise
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
logger.info(
f"[SSE-RECONNECT] get_task: meta_key={meta_key}, "
f"meta_keys={list(meta.keys()) if meta else 'empty'}, "
f"meta={meta}"
)
if not meta:
logger.info(f"[SSE-RECONNECT] No metadata found for task {task_id}")
return None
# Note: Redis client uses decode_responses=True, so keys/values are strings
task = ActiveTask(
task_id=meta.get("task_id", ""),
session_id=meta.get("session_id", ""),
user_id=meta.get("user_id", "") or None,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
operation_id=meta.get("operation_id", ""),
status=meta.get("status", "running"), # type: ignore[arg-type]
)
logger.info(
f"[SSE-RECONNECT] get_task returning: task_id={task.task_id}, "
f"session_id={task.session_id}, operation_id={task.operation_id}"
)
return task
async def get_active_task_for_session(
session_id: str,
user_id: str | None = None,
) -> tuple[ActiveTask | None, str]:
"""Get the active (running) task for a session, if any.
Scans Redis for tasks matching the session_id with status="running".
Args:
session_id: Session ID to look up
user_id: User ID for ownership validation (optional)
Returns:
Tuple of (ActiveTask if found and running, last_message_id from Redis Stream)
"""
logger.info(f"[SSE-RECONNECT] Looking for active task for session {session_id}")
redis = await get_redis_async()
# Scan Redis for task metadata keys
cursor = 0
tasks_checked = 0
while True:
cursor, keys = await redis.scan(
cursor, match=f"{config.task_meta_prefix}*", count=100
)
for key in keys:
tasks_checked += 1
meta: dict[Any, Any] = await redis.hgetall(key) # type: ignore[misc]
if not meta:
continue
# Note: Redis client uses decode_responses=True, so keys/values are strings
task_session_id = meta.get("session_id", "")
task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None
task_id = meta.get("task_id", "")
# Log tasks found for this session
if task_session_id == session_id:
logger.info(
f"[SSE-RECONNECT] Found task for session: "
f"task_id={task_id}, status={task_status}"
)
if task_session_id == session_id and task_status == "running":
# Validate ownership
if user_id and task_user_id and task_user_id != user_id:
logger.info(f"[SSE-RECONNECT] Task {task_id} ownership mismatch")
continue
# Get the last message ID from Redis Stream
stream_key = _get_task_stream_key(task_id)
last_id = "0-0"
try:
messages = await redis.xrevrange(stream_key, count=1)
if messages:
msg_id = messages[0][0]
last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
except Exception as e:
logger.warning(f"Failed to get last message ID: {e}")
logger.info(
f"[SSE-RECONNECT] Found active task: task_id={task_id}, "
f"last_message_id={last_id}"
)
return (
ActiveTask(
task_id=task_id,
session_id=task_session_id,
user_id=task_user_id,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
operation_id=meta.get("operation_id", ""),
status="running",
),
last_id,
)
if cursor == 0:
break
logger.info(
f"[SSE-RECONNECT] No active task found for session {session_id} "
f"(checked {tasks_checked} tasks)"
)
return None, "0-0"
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
"""Reconstruct a StreamBaseResponse from JSON data.
Args:
chunk_data: Parsed JSON data from Redis
Returns:
Reconstructed response object, or None if unknown type
"""
from .response_model import (
ResponseType,
StreamError,
StreamFinish,
StreamHeartbeat,
StreamStart,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
chunk_type = chunk_data.get("type")
try:
if chunk_type == ResponseType.START.value:
return StreamStart(**chunk_data)
elif chunk_type == ResponseType.FINISH.value:
return StreamFinish(**chunk_data)
elif chunk_type == ResponseType.TEXT_START.value:
return StreamTextStart(**chunk_data)
elif chunk_type == ResponseType.TEXT_DELTA.value:
return StreamTextDelta(**chunk_data)
elif chunk_type == ResponseType.TEXT_END.value:
return StreamTextEnd(**chunk_data)
elif chunk_type == ResponseType.TOOL_INPUT_START.value:
return StreamToolInputStart(**chunk_data)
elif chunk_type == ResponseType.TOOL_INPUT_AVAILABLE.value:
return StreamToolInputAvailable(**chunk_data)
elif chunk_type == ResponseType.TOOL_OUTPUT_AVAILABLE.value:
return StreamToolOutputAvailable(**chunk_data)
elif chunk_type == ResponseType.ERROR.value:
return StreamError(**chunk_data)
elif chunk_type == ResponseType.USAGE.value:
return StreamUsage(**chunk_data)
elif chunk_type == ResponseType.HEARTBEAT.value:
return StreamHeartbeat(**chunk_data)
else:
logger.warning(f"Unknown chunk type: {chunk_type}")
return None
except Exception as e:
logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}")
return None
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
"""Track the asyncio.Task for a task (local reference only).
This is just for cleanup purposes - the task state is in Redis.
Args:
task_id: Task ID
asyncio_task: The asyncio Task to track
"""
_local_tasks[task_id] = asyncio_task
async def unsubscribe_from_task(
task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
) -> None:
"""Clean up when a subscriber disconnects.
With Redis-based pub/sub, there's no explicit unsubscription needed.
The pub/sub listener task will be garbage collected when the subscriber
stops reading from the queue.
Args:
task_id: Task ID
subscriber_queue: The subscriber's queue (unused, kept for API compat)
"""
# No-op - pub/sub listener cleans up automatically
logger.debug(f"[SSE-RECONNECT] Subscriber disconnected from task {task_id}")

View File

@@ -549,15 +549,19 @@ async def decompose_goal(
async def generate_agent(
instructions: DecompositionResult | dict[str, Any],
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables RabbitMQ callback)
task_id: Task ID for async processing (enables RabbitMQ callback)
Returns:
Agent JSON dict, error dict {"type": "error", ...}, or None on error
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -565,8 +569,13 @@ async def generate_agent(
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(
dict(instructions), _to_dict_list(library_agents)
dict(instructions), _to_dict_list(library_agents), operation_id, task_id
)
# Don't modify async response
if result and result.get("status") == "accepted":
return result
if result:
if isinstance(result, dict) and result.get("type") == "error":
return result
@@ -806,6 +815,8 @@ async def generate_agent_patch(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[AgentSummary] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Update an existing agent using natural language.
@@ -818,10 +829,12 @@ async def generate_agent_patch(
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables RabbitMQ callback)
task_id: Task ID for async processing (enables RabbitMQ callback)
Returns:
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on unexpected error
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -829,5 +842,9 @@ async def generate_agent_patch(
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent_patch")
return await generate_agent_patch_external(
update_request, current_agent, _to_dict_list(library_agents)
update_request,
current_agent,
_to_dict_list(library_agents),
operation_id,
task_id,
)

View File

@@ -213,24 +213,45 @@ async def decompose_goal_external(
async def generate_agent_external(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate an agent from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables RabbitMQ callback)
task_id: Task ID for async processing (enables RabbitMQ callback)
Returns:
Agent JSON dict on success, or error dict {"type": "error", ...} on error
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
"""
client = _get_client()
# Build request payload
payload: dict[str, Any] = {"instructions": instructions}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post("/api/generate-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response.raise_for_status()
data = response.json()
@@ -262,6 +283,8 @@ async def generate_agent_patch_external(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate a patch for an existing agent.
@@ -269,21 +292,40 @@ async def generate_agent_patch_external(
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables RabbitMQ callback)
task_id: Task ID for async processing (enables RabbitMQ callback)
Returns:
Updated agent JSON, clarifying questions dict, or error dict on error
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
"""
client = _get_client()
# Build request payload
payload: dict[str, Any] = {
"update_request": update_request,
"current_agent_json": current_agent,
}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post("/api/update-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async update request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response.raise_for_status()
data = response.json()

View File

@@ -18,6 +18,7 @@ from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -98,6 +99,10 @@ class CreateAgentTool(BaseTool):
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not description:
return ErrorResponse(
message="Please provide a description of what the agent should do.",
@@ -219,7 +224,12 @@ class CreateAgentTool(BaseTool):
logger.warning(f"Failed to enrich library agents from steps: {e}")
try:
agent_json = await generate_agent(decomposition_result, library_agents)
agent_json = await generate_agent(
decomposition_result,
library_agents,
operation_id=operation_id,
task_id=task_id,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
@@ -263,6 +273,19 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if agent_json.get("status") == "accepted":
logger.info(
f"Agent generation delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent generation started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
agent_name = agent_json.get("name", "Generated Agent")
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))

View File

@@ -17,6 +17,7 @@ from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -104,6 +105,10 @@ class EditAgentTool(BaseTool):
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not agent_id:
return ErrorResponse(
message="Please provide the agent ID to edit.",
@@ -149,7 +154,11 @@ class EditAgentTool(BaseTool):
try:
result = await generate_agent_patch(
update_request, current_agent, library_agents
update_request,
current_agent,
library_agents,
operation_id=operation_id,
task_id=task_id,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
@@ -169,6 +178,20 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if result.get("status") == "accepted":
logger.info(
f"Agent edit delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent edit started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
# Check if the result is an error from the external service
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")

View File

@@ -352,11 +352,15 @@ class OperationStartedResponse(ToolResponseBase):
This is returned immediately to the client while the operation continues
to execute. The user can close the tab and check back later.
The task_id can be used to reconnect to the SSE stream via
GET /chat/tasks/{task_id}/stream?last_idx=0
"""
type: ResponseType = ResponseType.OPERATION_STARTED
operation_id: str
tool_name: str
task_id: str | None = None # For SSE reconnection
class OperationPendingResponse(ToolResponseBase):
@@ -380,3 +384,20 @@ class OperationInProgressResponse(ToolResponseBase):
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
tool_call_id: str
class AsyncProcessingResponse(ToolResponseBase):
"""Response when an operation has been delegated to async processing.
This is returned by tools when the external service accepts the request
for async processing (HTTP 202 Accepted). The RabbitMQ completion consumer
will handle the result when the external service completes.
The status field is specifically "accepted" to allow the long-running tool
handler to detect this response and skip LLM continuation.
"""
type: ResponseType = ResponseType.OPERATION_STARTED
status: str = "accepted" # Must be "accepted" for detection
operation_id: str | None = None
task_id: str | None = None

View File

@@ -40,6 +40,10 @@ import backend.data.user
import backend.integrations.webhooks.utils
import backend.util.service
import backend.util.settings
from backend.api.features.chat.completion_consumer import (
start_completion_consumer,
stop_completion_consumer,
)
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
@@ -118,9 +122,21 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
# Start chat completion consumer for RabbitMQ notifications
try:
await start_completion_consumer()
except Exception as e:
logger.warning(f"Could not start chat completion consumer: {e}")
with launch_darkly_context():
yield
# Stop chat completion consumer
try:
await stop_completion_consumer()
except Exception as e:
logger.warning(f"Error stopping chat completion consumer: {e}")
try:
await shutdown_cloud_storage_handler()
except Exception as e:

View File

@@ -11,7 +11,6 @@ import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useQueryClient } from "@tanstack/react-query";
import { usePathname, useSearchParams } from "next/navigation";
import { useRef } from "react";
import { useCopilotStore } from "../../copilot-page-store";
import { useCopilotSessionId } from "../../useCopilotSessionId";
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
@@ -70,41 +69,16 @@ export function useCopilotShell() {
});
const stopStream = useChatStore((s) => s.stopStream);
const onStreamComplete = useChatStore((s) => s.onStreamComplete);
const isStreaming = useCopilotStore((s) => s.isStreaming);
const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);
const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession);
const openInterruptModal = useCopilotStore((s) => s.openInterruptModal);
const pendingActionRef = useRef<(() => void) | null>(null);
async function stopCurrentStream() {
if (!currentSessionId) return;
setIsSwitchingSession(true);
await new Promise<void>((resolve) => {
const unsubscribe = onStreamComplete((completedId) => {
if (completedId === currentSessionId) {
clearTimeout(timeout);
unsubscribe();
resolve();
}
});
const timeout = setTimeout(() => {
unsubscribe();
resolve();
}, 3000);
stopStream(currentSessionId);
});
queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(currentSessionId),
});
setIsSwitchingSession(false);
}
function selectSession(sessionId: string) {
function handleSessionClick(sessionId: string) {
if (sessionId === currentSessionId) return;
// Stop current stream - SSE reconnection allows resuming later
if (currentSessionId) {
stopStream(currentSessionId);
}
if (recentlyCreatedSessionsRef.current.has(sessionId)) {
queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(sessionId),
@@ -114,7 +88,12 @@ export function useCopilotShell() {
if (isMobile) handleCloseDrawer();
}
function startNewChat() {
function handleNewChatClick() {
// Stop current stream - SSE reconnection allows resuming later
if (currentSessionId) {
stopStream(currentSessionId);
}
resetPagination();
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
@@ -123,32 +102,6 @@ export function useCopilotShell() {
if (isMobile) handleCloseDrawer();
}
function handleSessionClick(sessionId: string) {
if (sessionId === currentSessionId) return;
if (isStreaming) {
pendingActionRef.current = async () => {
await stopCurrentStream();
selectSession(sessionId);
};
openInterruptModal(pendingActionRef.current);
} else {
selectSession(sessionId);
}
}
function handleNewChatClick() {
if (isStreaming) {
pendingActionRef.current = async () => {
await stopCurrentStream();
startNewChat();
};
openInterruptModal(pendingActionRef.current);
} else {
startNewChat();
}
}
return {
isMobile,
isDrawerOpen,

View File

@@ -4,53 +4,25 @@ import { create } from "zustand";
interface CopilotStoreState {
isStreaming: boolean;
isSwitchingSession: boolean;
isCreatingSession: boolean;
isInterruptModalOpen: boolean;
pendingAction: (() => void) | null;
}
interface CopilotStoreActions {
setIsStreaming: (isStreaming: boolean) => void;
setIsSwitchingSession: (isSwitchingSession: boolean) => void;
setIsCreatingSession: (isCreating: boolean) => void;
openInterruptModal: (onConfirm: () => void) => void;
confirmInterrupt: () => void;
cancelInterrupt: () => void;
}
type CopilotStore = CopilotStoreState & CopilotStoreActions;
export const useCopilotStore = create<CopilotStore>((set, get) => ({
export const useCopilotStore = create<CopilotStore>((set) => ({
isStreaming: false,
isSwitchingSession: false,
isCreatingSession: false,
isInterruptModalOpen: false,
pendingAction: null,
setIsStreaming(isStreaming) {
set({ isStreaming });
},
setIsSwitchingSession(isSwitchingSession) {
set({ isSwitchingSession });
},
setIsCreatingSession(isCreatingSession) {
set({ isCreatingSession });
},
openInterruptModal(onConfirm) {
set({ isInterruptModalOpen: true, pendingAction: onConfirm });
},
confirmInterrupt() {
const { pendingAction } = get();
set({ isInterruptModalOpen: false, pendingAction: null });
if (pendingAction) pendingAction();
},
cancelInterrupt() {
set({ isInterruptModalOpen: false, pendingAction: null });
},
}));

View File

@@ -5,15 +5,10 @@ import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
import { Text } from "@/components/atoms/Text/Text";
import { Chat } from "@/components/contextual/Chat/Chat";
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { useCopilotStore } from "./copilot-page-store";
import { useCopilotPage } from "./useCopilotPage";
export default function CopilotPage() {
const { state, handlers } = useCopilotPage();
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
const {
greetingName,
quickActions,
@@ -40,42 +35,6 @@ export default function CopilotPage() {
onSessionNotFound={handleSessionNotFound}
onStreamingChange={handleStreamingChange}
/>
<Dialog
title="Interrupt current chat?"
styling={{ maxWidth: 300, width: "100%" }}
controlled={{
isOpen: isInterruptModalOpen,
set: (open) => {
if (!open) cancelInterrupt();
},
}}
onClose={cancelInterrupt}
>
<Dialog.Content>
<div className="flex flex-col gap-4">
<Text variant="body">
The current chat response will be interrupted. Are you sure you
want to continue?
</Text>
<Dialog.Footer>
<Button
type="button"
variant="outline"
onClick={cancelInterrupt}
>
Cancel
</Button>
<Button
type="button"
variant="primary"
onClick={confirmInterrupt}
>
Continue
</Button>
</Dialog.Footer>
</div>
</Dialog.Content>
</Dialog>
</div>
);
}

View File

@@ -5,7 +5,6 @@ import {
import { useToast } from "@/components/molecules/Toast/use-toast";
import { getHomepageRoute } from "@/lib/constants";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
import {
Flag,
type FlagValues,
@@ -26,20 +25,12 @@ export function useCopilotPage() {
const queryClient = useQueryClient();
const { user, isLoggedIn, isUserLoading } = useSupabase();
const { toast } = useToast();
const { completeStep } = useOnboarding();
const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
const setIsStreaming = useCopilotStore((s) => s.setIsStreaming);
const isCreating = useCopilotStore((s) => s.isCreatingSession);
const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession);
// Complete VISIT_COPILOT onboarding step to grant $5 welcome bonus
useEffect(() => {
if (isLoggedIn) {
completeStep("VISIT_COPILOT");
}
}, [completeStep, isLoggedIn]);
const isChatEnabled = useGetFlag(Flag.CHAT);
const flags = useFlags<FlagValues>();
const homepageRoute = getHomepageRoute(isChatEnabled);

View File

@@ -0,0 +1,81 @@
import { environment } from "@/services/environment";
import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
import { NextRequest } from "next/server";
/**
* SSE Proxy for task stream reconnection.
*
* This endpoint allows clients to reconnect to an ongoing or recently completed
* background task's stream. It replays missed messages from Redis Streams and
* subscribes to live updates if the task is still running.
*
* Client contract:
* 1. When receiving an operation_started event, store the task_id
* 2. To reconnect: GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
* 3. Messages are replayed from the last_message_id position
* 4. Stream ends when "finish" event is received
*/
export async function GET(
request: NextRequest,
{ params }: { params: Promise<{ taskId: string }> },
) {
const { taskId } = await params;
const searchParams = request.nextUrl.searchParams;
const lastMessageId = searchParams.get("last_message_id") || "0-0";
try {
// Get auth token from server-side session
const token = await getServerAuthToken();
// Build backend URL
const backendUrl = environment.getAGPTServerBaseUrl();
const streamUrl = new URL(`/api/chat/tasks/${taskId}/stream`, backendUrl);
streamUrl.searchParams.set("last_message_id", lastMessageId);
// Forward request to backend with auth header
const headers: Record<string, string> = {
Accept: "text/event-stream",
"Cache-Control": "no-cache",
Connection: "keep-alive",
};
if (token) {
headers["Authorization"] = `Bearer ${token}`;
}
const response = await fetch(streamUrl.toString(), {
method: "GET",
headers,
});
if (!response.ok) {
const error = await response.text();
return new Response(error, {
status: response.status,
headers: { "Content-Type": "application/json" },
});
}
// Return the SSE stream directly
return new Response(response.body, {
headers: {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache, no-transform",
Connection: "keep-alive",
"X-Accel-Buffering": "no",
},
});
} catch (error) {
console.error("Task stream proxy error:", error);
return new Response(
JSON.stringify({
error: "Failed to connect to task stream",
detail: error instanceof Error ? error.message : String(error),
}),
{
status: 500,
headers: { "Content-Type": "application/json" },
},
);
}
}

View File

@@ -939,6 +939,63 @@
}
}
},
"/api/chat/operations/{operation_id}/complete": {
"post": {
"tags": ["v2", "chat", "chat"],
"summary": "Complete Operation",
"description": "External completion webhook for long-running operations.\n\nCalled by Agent Generator (or other services) when an operation completes.\nThis triggers the stream registry to publish completion and continue LLM generation.\n\nArgs:\n operation_id: The operation ID to complete.\n request: Completion payload with success status and result/error.\n x_api_key: Internal API key for authentication.\n\nReturns:\n dict: Status of the completion.\n\nRaises:\n HTTPException: If API key is invalid or operation not found.",
"operationId": "postV2CompleteOperation",
"parameters": [
{
"name": "operation_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Operation Id" }
},
{
"name": "x-api-key",
"in": "header",
"required": false,
"schema": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "X-Api-Key"
}
}
],
"requestBody": {
"required": true,
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/OperationCompleteRequest"
}
}
}
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"type": "object",
"additionalProperties": true,
"title": "Response Postv2Completeoperation"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/chat/sessions": {
"get": {
"tags": ["v2", "chat", "chat"],
@@ -1022,7 +1079,7 @@
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Get Session",
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, or None if not found.",
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\nIf there's an active stream for this session, returns the task_id for reconnection.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, including active_stream info if applicable.",
"operationId": "getV2GetSession",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
@@ -1157,7 +1214,7 @@
"post": {
"tags": ["v2", "chat", "chat"],
"summary": "Stream Chat Post",
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to Redis for reconnection support. If the client disconnects,\nthey can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks. First chunk is a \"start\" event\n containing the task_id for reconnection.",
"operationId": "postV2StreamChatPost",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
@@ -1195,6 +1252,94 @@
}
}
},
"/api/chat/tasks/{task_id}": {
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Get Task Status",
"description": "Get the status of a long-running task.\n\nArgs:\n task_id: The task ID to check.\n user_id: Authenticated user ID for ownership validation.\n\nReturns:\n dict: Task status including task_id, status, tool_name, and operation_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.",
"operationId": "getV2GetTaskStatus",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "task_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Task Id" }
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"type": "object",
"additionalProperties": true,
"title": "Response Getv2Gettaskstatus"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/chat/tasks/{task_id}/stream": {
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Stream Task",
"description": "Reconnect to a long-running task's SSE stream.\n\nWhen a long-running operation (like agent generation) starts, the client\nreceives a task_id. If the connection drops, the client can reconnect\nusing this endpoint to resume receiving updates.\n\nArgs:\n task_id: The task ID from the operation_started response.\n user_id: Authenticated user ID for ownership validation.\n last_message_id: Last Redis Stream message ID received (\"0-0\" for full replay).\n\nReturns:\n StreamingResponse: SSE-formatted response chunks starting after last_message_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.",
"operationId": "getV2StreamTask",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "task_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Task Id" }
},
{
"name": "last_message_id",
"in": "query",
"required": false,
"schema": {
"type": "string",
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
"default": "0-0",
"title": "Last Message Id"
},
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay."
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": { "application/json": { "schema": {} } }
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/credits": {
"get": {
"tags": ["v1", "credits"],
@@ -6168,6 +6313,16 @@
"title": "AccuracyTrendsResponse",
"description": "Response model for accuracy trends and alerts."
},
"ActiveStreamInfo": {
"properties": {
"task_id": { "type": "string", "title": "Task Id" },
"last_message_id": { "type": "string", "title": "Last Message Id" }
},
"type": "object",
"required": ["task_id", "last_message_id"],
"title": "ActiveStreamInfo",
"description": "Information about an active stream for reconnection."
},
"AddUserCreditsResponse": {
"properties": {
"new_balance": { "type": "integer", "title": "New Balance" },
@@ -8823,6 +8978,27 @@
],
"title": "OnboardingStep"
},
"OperationCompleteRequest": {
"properties": {
"success": { "type": "boolean", "title": "Success" },
"result": {
"anyOf": [
{ "additionalProperties": true, "type": "object" },
{ "type": "string" },
{ "type": "null" }
],
"title": "Result"
},
"error": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Error"
}
},
"type": "object",
"required": ["success"],
"title": "OperationCompleteRequest",
"description": "Request model for external completion webhook."
},
"Pagination": {
"properties": {
"total_items": {
@@ -9678,6 +9854,12 @@
"items": { "additionalProperties": true, "type": "object" },
"type": "array",
"title": "Messages"
},
"active_stream": {
"anyOf": [
{ "$ref": "#/components/schemas/ActiveStreamInfo" },
{ "type": "null" }
]
}
},
"type": "object",

View File

@@ -1,7 +1,6 @@
"use client";
import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId";
import { useCopilotStore } from "@/app/(platform)/copilot/copilot-page-store";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
@@ -25,8 +24,8 @@ export function Chat({
}: ChatProps) {
const { urlSessionId } = useCopilotSessionId();
const hasHandledNotFoundRef = useRef(false);
const isSwitchingSession = useCopilotStore((s) => s.isSwitchingSession);
const {
session,
messages,
isLoading,
isCreating,
@@ -38,6 +37,21 @@ export function Chat({
startPollingForOperation,
} = useChat({ urlSessionId });
// Extract active stream info for reconnection
const activeStream = (session as { active_stream?: { task_id: string; last_message_id: string } })?.active_stream;
// Debug logging for SSE reconnection
if (session) {
console.info("[SSE-RECONNECT] Session loaded:", {
sessionId,
hasActiveStream: !!activeStream,
activeStream: activeStream ? {
taskId: activeStream.task_id,
lastMessageId: activeStream.last_message_id,
} : null,
});
}
useEffect(() => {
if (!onSessionNotFound) return;
if (!urlSessionId) return;
@@ -53,8 +67,7 @@ export function Chat({
isCreating,
]);
const shouldShowLoader =
(showLoader && (isLoading || isCreating)) || isSwitchingSession;
const shouldShowLoader = showLoader && (isLoading || isCreating);
return (
<div className={cn("flex h-full flex-col", className)}>
@@ -66,21 +79,19 @@ export function Chat({
<div className="flex flex-col items-center gap-3">
<LoadingSpinner size="large" className="text-neutral-400" />
<Text variant="body" className="text-zinc-500">
{isSwitchingSession
? "Switching chat..."
: "Loading your chat..."}
Loading your chat...
</Text>
</div>
</div>
)}
{/* Error State */}
{error && !isLoading && !isSwitchingSession && (
{error && !isLoading && (
<ChatErrorState error={error} onRetry={createSession} />
)}
{/* Session Content */}
{sessionId && !isLoading && !error && !isSwitchingSession && (
{sessionId && !isLoading && !error && (
<ChatContainer
sessionId={sessionId}
initialMessages={messages}
@@ -88,6 +99,10 @@ export function Chat({
className="flex-1"
onStreamingChange={onStreamingChange}
onOperationStarted={startPollingForOperation}
activeStream={activeStream ? {
taskId: activeStream.task_id,
lastMessageId: activeStream.last_message_id,
} : undefined}
/>
)}
</main>

View File

@@ -0,0 +1,159 @@
# SSE Reconnection Contract for Long-Running Operations
This document describes the client-side contract for handling SSE (Server-Sent Events) disconnections and reconnecting to long-running background tasks.
## Overview
When a user triggers a long-running operation (like agent generation), the backend:
1. Spawns a background task that survives SSE disconnections
2. Returns an `operation_started` response with a `task_id`
3. Stores stream messages in Redis Streams for replay
Clients can reconnect to the task stream at any time to receive missed messages.
## Client-Side Flow
### 1. Receiving Operation Started
When you receive an `operation_started` tool response:
```typescript
// The response includes a task_id for reconnection
{
type: "operation_started",
tool_name: "generate_agent",
operation_id: "uuid-...",
task_id: "task-uuid-...", // <-- Store this for reconnection
message: "Operation started. You can close this tab."
}
```
### 2. Storing Task Info
Use the chat store to track the active task:
```typescript
import { useChatStore } from "./chat-store";
// When operation_started is received:
useChatStore.getState().setActiveTask(sessionId, {
taskId: response.task_id,
operationId: response.operation_id,
toolName: response.tool_name,
lastMessageId: "0",
});
```
### 3. Reconnecting to a Task
To reconnect (e.g., after page refresh or tab reopen):
```typescript
const { reconnectToTask, getActiveTask } = useChatStore.getState();
// Check if there's an active task for this session
const activeTask = getActiveTask(sessionId);
if (activeTask) {
// Reconnect to the task stream
await reconnectToTask(
sessionId,
activeTask.taskId,
activeTask.lastMessageId, // Resume from last position
(chunk) => {
// Handle incoming chunks
console.log("Received chunk:", chunk);
},
);
}
```
### 4. Tracking Message Position
To enable precise replay, update the last message ID as chunks arrive:
```typescript
const { updateTaskLastMessageId } = useChatStore.getState();
function handleChunk(chunk: StreamChunk) {
// If chunk has an index/id, track it
if (chunk.idx !== undefined) {
updateTaskLastMessageId(sessionId, String(chunk.idx));
}
}
```
## API Endpoints
### Task Stream Reconnection
```
GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
```
- `taskId`: The task ID from `operation_started`
- `last_message_id`: Last received message index (default: "0" for full replay)
Returns: SSE stream of missed messages + live updates
## Chunk Types
The reconnected stream follows the same Vercel AI SDK protocol:
| Type | Description |
| ----------------------- | ----------------------- |
| `start` | Message lifecycle start |
| `text-delta` | Streaming text content |
| `text-end` | Text block completed |
| `tool-output-available` | Tool result available |
| `finish` | Stream completed |
| `error` | Error occurred |
## Error Handling
If reconnection fails:
1. Check if task still exists (may have expired - default TTL: 1 hour)
2. Fall back to polling the session for final state
3. Show appropriate UI message to user
## Persistence Considerations
For robust reconnection across browser restarts:
```typescript
// Store in localStorage/sessionStorage
const ACTIVE_TASKS_KEY = "chat_active_tasks";
function persistActiveTask(sessionId: string, task: ActiveTaskInfo) {
const tasks = JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
tasks[sessionId] = task;
localStorage.setItem(ACTIVE_TASKS_KEY, JSON.stringify(tasks));
}
function loadPersistedTasks(): Record<string, ActiveTaskInfo> {
return JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
}
```
## Backend Configuration
The following backend settings affect reconnection behavior:
| Setting | Default | Description |
| ------------------- | ------- | ---------------------------------- |
| `stream_ttl` | 3600s | How long streams are kept in Redis |
| `stream_max_length` | 1000 | Max messages per stream |
## Testing
To test reconnection locally:
1. Start a long-running operation (e.g., agent generation)
2. Note the `task_id` from the `operation_started` response
3. Close the browser tab
4. Reopen and call `reconnectToTask` with the saved `task_id`
5. Verify that missed messages are replayed
See the main README for full local development setup.

View File

@@ -0,0 +1,16 @@
/**
* Constants for the chat system.
*
* Centralizes magic strings and values used across chat components.
*/
// LocalStorage keys
export const STORAGE_KEY_ACTIVE_TASKS = "chat_active_tasks";
// Redis Stream IDs
export const INITIAL_MESSAGE_ID = "0";
export const INITIAL_STREAM_ID = "0-0";
// TTL values (in milliseconds)
export const COMPLETED_STREAM_TTL_MS = 5 * 60 * 1000; // 5 minutes
export const ACTIVE_TASK_TTL_MS = 60 * 60 * 1000; // 1 hour

View File

@@ -1,6 +1,12 @@
"use client";
import { create } from "zustand";
import {
ACTIVE_TASK_TTL_MS,
COMPLETED_STREAM_TTL_MS,
INITIAL_STREAM_ID,
STORAGE_KEY_ACTIVE_TASKS,
} from "./chat-constants";
import type {
ActiveStream,
StreamChunk,
@@ -8,15 +14,64 @@ import type {
StreamResult,
StreamStatus,
} from "./chat-types";
import { executeStream } from "./stream-executor";
import { executeStream, executeTaskReconnect } from "./stream-executor";
const COMPLETED_STREAM_TTL = 5 * 60 * 1000; // 5 minutes
/**
* Tracks active task info for SSE reconnection.
* When a long-running operation starts, we store this so clients can reconnect
* if the browser tab is closed and reopened.
*/
export interface ActiveTaskInfo {
taskId: string;
sessionId: string;
operationId: string;
toolName: string;
lastMessageId: string; // Last processed message ID for replay (Redis Stream format: "0-0")
startedAt: number;
}
/** Load active tasks from localStorage */
function loadPersistedTasks(): Map<string, ActiveTaskInfo> {
if (typeof window === "undefined") return new Map();
try {
const stored = localStorage.getItem(STORAGE_KEY_ACTIVE_TASKS);
if (!stored) return new Map();
const parsed = JSON.parse(stored) as Record<string, ActiveTaskInfo>;
const now = Date.now();
const tasks = new Map<string, ActiveTaskInfo>();
// Filter out expired tasks
for (const [sessionId, task] of Object.entries(parsed)) {
if (now - task.startedAt < ACTIVE_TASK_TTL_MS) {
tasks.set(sessionId, task);
}
}
return tasks;
} catch {
return new Map();
}
}
/** Save active tasks to localStorage */
function persistTasks(tasks: Map<string, ActiveTaskInfo>): void {
if (typeof window === "undefined") return;
try {
const obj: Record<string, ActiveTaskInfo> = {};
for (const [sessionId, task] of tasks) {
obj[sessionId] = task;
}
localStorage.setItem(STORAGE_KEY_ACTIVE_TASKS, JSON.stringify(obj));
} catch {
// Ignore storage errors
}
}
interface ChatStoreState {
activeStreams: Map<string, ActiveStream>;
completedStreams: Map<string, StreamResult>;
activeSessions: Set<string>;
streamCompleteCallbacks: Set<StreamCompleteCallback>;
/** Active tasks for SSE reconnection - keyed by sessionId */
activeTasks: Map<string, ActiveTaskInfo>;
}
interface ChatStoreActions {
@@ -41,6 +96,24 @@ interface ChatStoreActions {
unregisterActiveSession: (sessionId: string) => void;
isSessionActive: (sessionId: string) => boolean;
onStreamComplete: (callback: StreamCompleteCallback) => () => void;
/** Track active task for SSE reconnection */
setActiveTask: (
sessionId: string,
taskInfo: Omit<ActiveTaskInfo, "sessionId" | "startedAt">,
) => void;
/** Get active task for a session */
getActiveTask: (sessionId: string) => ActiveTaskInfo | undefined;
/** Clear active task when operation completes */
clearActiveTask: (sessionId: string) => void;
/** Reconnect to an existing task stream */
reconnectToTask: (
sessionId: string,
taskId: string,
lastMessageId?: string,
onChunk?: (chunk: StreamChunk) => void,
) => Promise<void>;
/** Update last message ID for a task (for tracking replay position) */
updateTaskLastMessageId: (sessionId: string, lastMessageId: string) => void;
}
type ChatStore = ChatStoreState & ChatStoreActions;
@@ -64,18 +137,79 @@ function cleanupExpiredStreams(
const now = Date.now();
const cleaned = new Map(completedStreams);
for (const [sessionId, result] of cleaned) {
if (now - result.completedAt > COMPLETED_STREAM_TTL) {
if (now - result.completedAt > COMPLETED_STREAM_TTL_MS) {
cleaned.delete(sessionId);
}
}
return cleaned;
}
/**
* Clean up an existing stream for a session and move it to completed streams.
* Returns updated maps for both active and completed streams.
*/
function cleanupExistingStream(
sessionId: string,
activeStreams: Map<string, ActiveStream>,
completedStreams: Map<string, StreamResult>,
callbacks: Set<StreamCompleteCallback>,
): {
activeStreams: Map<string, ActiveStream>;
completedStreams: Map<string, StreamResult>;
} {
const newActiveStreams = new Map(activeStreams);
let newCompletedStreams = new Map(completedStreams);
const existingStream = newActiveStreams.get(sessionId);
if (existingStream) {
existingStream.abortController.abort();
const normalizedStatus =
existingStream.status === "streaming" ? "completed" : existingStream.status;
const result: StreamResult = {
sessionId,
status: normalizedStatus,
chunks: existingStream.chunks,
completedAt: Date.now(),
error: existingStream.error,
};
newCompletedStreams.set(sessionId, result);
newActiveStreams.delete(sessionId);
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
if (normalizedStatus === "completed" || normalizedStatus === "error") {
notifyStreamComplete(callbacks, sessionId);
}
}
return { activeStreams: newActiveStreams, completedStreams: newCompletedStreams };
}
/**
* Create a new active stream with initial state.
*/
function createActiveStream(
sessionId: string,
onChunk?: (chunk: StreamChunk) => void,
): ActiveStream {
const abortController = new AbortController();
const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
if (onChunk) initialCallbacks.add(onChunk);
return {
sessionId,
abortController,
status: "streaming",
startedAt: Date.now(),
chunks: [],
onChunkCallbacks: initialCallbacks,
};
}
export const useChatStore = create<ChatStore>((set, get) => ({
activeStreams: new Map(),
completedStreams: new Map(),
activeSessions: new Set(),
streamCompleteCallbacks: new Set(),
activeTasks: loadPersistedTasks(),
startStream: async function startStream(
sessionId,
@@ -85,45 +219,19 @@ export const useChatStore = create<ChatStore>((set, get) => ({
onChunk,
) {
const state = get();
const newActiveStreams = new Map(state.activeStreams);
let newCompletedStreams = new Map(state.completedStreams);
const callbacks = state.streamCompleteCallbacks;
const existingStream = newActiveStreams.get(sessionId);
if (existingStream) {
existingStream.abortController.abort();
const normalizedStatus =
existingStream.status === "streaming"
? "completed"
: existingStream.status;
const result: StreamResult = {
// Clean up any existing stream for this session
const { activeStreams: newActiveStreams, completedStreams: newCompletedStreams } =
cleanupExistingStream(
sessionId,
status: normalizedStatus,
chunks: existingStream.chunks,
completedAt: Date.now(),
error: existingStream.error,
};
newCompletedStreams.set(sessionId, result);
newActiveStreams.delete(sessionId);
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
if (normalizedStatus === "completed" || normalizedStatus === "error") {
notifyStreamComplete(callbacks, sessionId);
}
}
const abortController = new AbortController();
const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
if (onChunk) initialCallbacks.add(onChunk);
const stream: ActiveStream = {
sessionId,
abortController,
status: "streaming",
startedAt: Date.now(),
chunks: [],
onChunkCallbacks: initialCallbacks,
};
state.activeStreams,
state.completedStreams,
callbacks,
);
// Create new stream
const stream = createActiveStream(sessionId, onChunk);
newActiveStreams.set(sessionId, stream);
set({
activeStreams: newActiveStreams,
@@ -286,4 +394,134 @@ export const useChatStore = create<ChatStore>((set, get) => ({
set({ streamCompleteCallbacks: cleanedCallbacks });
};
},
setActiveTask: function setActiveTask(sessionId, taskInfo) {
const state = get();
const newActiveTasks = new Map(state.activeTasks);
newActiveTasks.set(sessionId, {
...taskInfo,
sessionId,
startedAt: Date.now(),
});
set({ activeTasks: newActiveTasks });
persistTasks(newActiveTasks);
},
getActiveTask: function getActiveTask(sessionId) {
return get().activeTasks.get(sessionId);
},
clearActiveTask: function clearActiveTask(sessionId) {
const state = get();
if (!state.activeTasks.has(sessionId)) return;
const newActiveTasks = new Map(state.activeTasks);
newActiveTasks.delete(sessionId);
set({ activeTasks: newActiveTasks });
persistTasks(newActiveTasks);
},
reconnectToTask: async function reconnectToTask(
sessionId,
taskId,
lastMessageId = INITIAL_STREAM_ID,
onChunk,
) {
console.info("[SSE-RECONNECT] reconnectToTask called:", {
sessionId,
taskId,
lastMessageId,
});
const state = get();
const callbacks = state.streamCompleteCallbacks;
// Clean up any existing stream for this session
const { activeStreams: newActiveStreams, completedStreams: newCompletedStreams } =
cleanupExistingStream(
sessionId,
state.activeStreams,
state.completedStreams,
callbacks,
);
// Create new stream for reconnection
const stream = createActiveStream(sessionId, onChunk);
newActiveStreams.set(sessionId, stream);
set({
activeStreams: newActiveStreams,
completedStreams: newCompletedStreams,
});
console.info("[SSE-RECONNECT] Starting executeTaskReconnect...");
try {
await executeTaskReconnect(stream, taskId, lastMessageId);
console.info("[SSE-RECONNECT] executeTaskReconnect completed:", {
sessionId,
taskId,
streamStatus: stream.status,
});
} finally {
if (onChunk) stream.onChunkCallbacks.delete(onChunk);
if (stream.status !== "streaming") {
const currentState = get();
const finalActiveStreams = new Map(currentState.activeStreams);
let finalCompletedStreams = new Map(currentState.completedStreams);
const storedStream = finalActiveStreams.get(sessionId);
if (storedStream === stream) {
const result: StreamResult = {
sessionId,
status: stream.status,
chunks: stream.chunks,
completedAt: Date.now(),
error: stream.error,
};
finalCompletedStreams.set(sessionId, result);
finalActiveStreams.delete(sessionId);
finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
set({
activeStreams: finalActiveStreams,
completedStreams: finalCompletedStreams,
});
if (stream.status === "completed" || stream.status === "error") {
notifyStreamComplete(
currentState.streamCompleteCallbacks,
sessionId,
);
// Clear active task on completion
const taskState = get();
const newActiveTasks = new Map(taskState.activeTasks);
const hadActiveTask = newActiveTasks.has(sessionId);
newActiveTasks.delete(sessionId);
set({ activeTasks: newActiveTasks });
persistTasks(newActiveTasks);
if (hadActiveTask) {
console.info(
`[ChatStore] Cleared active task for session ${sessionId} ` +
`(stream status: ${stream.status})`,
);
}
}
}
}
}
},
updateTaskLastMessageId: function updateTaskLastMessageId(
sessionId,
lastMessageId,
) {
const state = get();
const task = state.activeTasks.get(sessionId);
if (!task) return;
const newActiveTasks = new Map(state.activeTasks);
newActiveTasks.set(sessionId, {
...task,
lastMessageId,
});
set({ activeTasks: newActiveTasks });
persistTasks(newActiveTasks);
},
}));

View File

@@ -4,6 +4,7 @@ export type StreamStatus = "idle" | "streaming" | "completed" | "error";
export interface StreamChunk {
type:
| "stream_start"
| "text_chunk"
| "text_ended"
| "tool_call"
@@ -15,6 +16,7 @@ export interface StreamChunk {
| "error"
| "usage"
| "stream_end";
taskId?: string; // Task ID for SSE reconnection
timestamp?: string;
content?: string;
message?: string;
@@ -41,7 +43,7 @@ export interface StreamChunk {
}
export type VercelStreamChunk =
| { type: "start"; messageId: string }
| { type: "start"; messageId: string; taskId?: string }
| { type: "finish" }
| { type: "text-start"; id: string }
| { type: "text-delta"; id: string; delta: string }
@@ -92,3 +94,67 @@ export interface StreamResult {
}
export type StreamCompleteCallback = (sessionId: string) => void;
// Type guards for message types
/**
* Check if a message has a toolId property.
*/
export function hasToolId<T extends { type: string }>(
msg: T,
): msg is T & { toolId: string } {
return "toolId" in msg && typeof (msg as Record<string, unknown>).toolId === "string";
}
/**
* Check if a message has an operationId property.
*/
export function hasOperationId<T extends { type: string }>(
msg: T,
): msg is T & { operationId: string } {
return (
"operationId" in msg &&
typeof (msg as Record<string, unknown>).operationId === "string"
);
}
/**
* Check if a message has a toolCallId property.
*/
export function hasToolCallId<T extends { type: string }>(
msg: T,
): msg is T & { toolCallId: string } {
return (
"toolCallId" in msg &&
typeof (msg as Record<string, unknown>).toolCallId === "string"
);
}
/**
* Check if a message is an operation message type.
*/
export function isOperationMessage<T extends { type: string }>(
msg: T,
): msg is T & {
type: "operation_started" | "operation_pending" | "operation_in_progress";
} {
return (
msg.type === "operation_started" ||
msg.type === "operation_pending" ||
msg.type === "operation_in_progress"
);
}
/**
* Get the tool ID from a message if available.
* Checks toolId, operationId, and toolCallId properties.
*/
export function getToolIdFromMessage<T extends { type: string }>(
msg: T,
): string | undefined {
const record = msg as Record<string, unknown>;
if (typeof record.toolId === "string") return record.toolId;
if (typeof record.operationId === "string") return record.operationId;
if (typeof record.toolCallId === "string") return record.toolCallId;
return undefined;
}

View File

@@ -17,6 +17,11 @@ export interface ChatContainerProps {
className?: string;
onStreamingChange?: (isStreaming: boolean) => void;
onOperationStarted?: () => void;
/** Active stream info from the server for reconnection */
activeStream?: {
taskId: string;
lastMessageId: string;
};
}
export function ChatContainer({
@@ -26,6 +31,7 @@ export function ChatContainer({
className,
onStreamingChange,
onOperationStarted,
activeStream,
}: ChatContainerProps) {
const {
messages,
@@ -41,6 +47,7 @@ export function ChatContainer({
initialMessages,
initialPrompt,
onOperationStarted,
activeStream,
});
useEffect(() => {

View File

@@ -34,6 +34,22 @@ export function createStreamEventDispatcher(
}
switch (chunk.type) {
case "stream_start":
// Store task ID for SSE reconnection
if (chunk.taskId && deps.onActiveTaskStarted) {
console.info("[ChatStream] Stream started with task ID:", {
sessionId: deps.sessionId,
taskId: chunk.taskId,
});
deps.onActiveTaskStarted({
taskId: chunk.taskId,
operationId: chunk.taskId, // Use taskId as operationId for chat streams
toolName: "chat",
toolCallId: "chat_stream",
});
}
break;
case "text_chunk":
handleTextChunk(chunk, deps);
break;
@@ -56,7 +72,8 @@ export function createStreamEventDispatcher(
break;
case "stream_end":
console.info("[ChatStream] Stream ended:", {
// Note: "finish" type from backend gets normalized to "stream_end" by normalizeStreamChunk
console.info("[SSE-RECONNECT] Stream ended:", {
sessionId: deps.sessionId,
hasResponse: deps.hasResponseRef.current,
chunkCount: deps.streamingChunksRef.current.length,

View File

@@ -23,6 +23,12 @@ export interface HandlerDependencies {
setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
sessionId: string;
onOperationStarted?: () => void;
onActiveTaskStarted?: (taskInfo: {
taskId: string;
operationId: string;
toolName: string;
toolCallId: string;
}) => void;
}
export function isRegionBlockedError(chunk: StreamChunk): boolean {
@@ -164,9 +170,19 @@ export function handleToolResponse(
}
return;
}
// Trigger polling when operation_started is received
// Trigger polling and store task info when operation_started is received
if (responseMessage.type === "operation_started") {
deps.onOperationStarted?.();
// Store task info for SSE reconnection if taskId is present
const taskId = (responseMessage as any).taskId;
if (taskId && deps.onActiveTaskStarted) {
deps.onActiveTaskStarted({
taskId,
operationId: (responseMessage as any).operationId || "",
toolName: (responseMessage as any).toolName || "",
toolCallId: (responseMessage as any).toolId || "",
});
}
}
deps.setMessages((prev) => {
@@ -205,8 +221,10 @@ export function handleStreamEnd(
_chunk: StreamChunk,
deps: HandlerDependencies,
) {
console.info("[SSE-RECONNECT] handleStreamEnd called, resetting streaming state");
const completedContent = deps.streamingChunksRef.current.join("");
if (!completedContent.trim() && !deps.hasResponseRef.current) {
console.info("[SSE-RECONNECT] No content received, adding placeholder message");
deps.setMessages((prev) => [
...prev,
{
@@ -245,10 +263,14 @@ export function handleStreamEnd(
export function handleError(chunk: StreamChunk, deps: HandlerDependencies) {
const errorMessage = chunk.message || chunk.content || "An error occurred";
console.error("Stream error:", errorMessage);
console.error("[ChatStream] Stream error:", errorMessage, {
sessionId: deps.sessionId,
chunk,
});
if (isRegionBlockedError(chunk)) {
deps.setIsRegionBlockedModalOpen(true);
}
console.info("[ChatStream] Resetting streaming state due to error");
deps.setIsStreamingInitiated(false);
deps.setHasTextChunks(false);
deps.setStreamingChunks([]);

View File

@@ -349,6 +349,7 @@ export function parseToolResponse(
toolName: (parsedResult.tool_name as string) || toolName,
toolId,
operationId: (parsedResult.operation_id as string) || "",
taskId: (parsedResult.task_id as string) || undefined, // For SSE reconnection
message:
(parsedResult.message as string) ||
"Operation started. You can close this tab.",

View File

@@ -1,10 +1,16 @@
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
import { useEffect, useMemo, useRef, useState } from "react";
import { INITIAL_STREAM_ID } from "../../chat-constants";
import { useChatStore } from "../../chat-store";
import { toast } from "sonner";
import { useChatStream } from "../../useChatStream";
import { usePageContext } from "../../usePageContext";
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
import {
getToolIdFromMessage,
hasToolId,
isOperationMessage,
} from "../../chat-types";
import { createStreamEventDispatcher } from "./createStreamEventDispatcher";
import {
createUserMessage,
@@ -14,6 +20,46 @@ import {
processInitialMessages,
} from "./helpers";
/**
* Dependencies for creating a stream event dispatcher.
* Extracted to allow helper function creation.
*/
interface DispatcherDeps {
setHasTextChunks: React.Dispatch<React.SetStateAction<boolean>>;
setStreamingChunks: React.Dispatch<React.SetStateAction<string[]>>;
streamingChunksRef: React.MutableRefObject<string[]>;
hasResponseRef: React.MutableRefObject<boolean>;
setMessages: React.Dispatch<React.SetStateAction<ChatMessageData[]>>;
setIsRegionBlockedModalOpen: React.Dispatch<React.SetStateAction<boolean>>;
sessionId: string;
setIsStreamingInitiated: React.Dispatch<React.SetStateAction<boolean>>;
onOperationStarted?: () => void;
onActiveTaskStarted: (taskInfo: {
taskId: string;
operationId: string;
toolName: string;
toolCallId: string;
}) => void;
}
/**
* Create a stream event dispatcher with the given dependencies.
*/
function createDispatcher(deps: DispatcherDeps) {
return createStreamEventDispatcher({
setHasTextChunks: deps.setHasTextChunks,
setStreamingChunks: deps.setStreamingChunks,
streamingChunksRef: deps.streamingChunksRef,
hasResponseRef: deps.hasResponseRef,
setMessages: deps.setMessages,
setIsRegionBlockedModalOpen: deps.setIsRegionBlockedModalOpen,
sessionId: deps.sessionId,
setIsStreamingInitiated: deps.setIsStreamingInitiated,
onOperationStarted: deps.onOperationStarted,
onActiveTaskStarted: deps.onActiveTaskStarted,
});
}
// Helper to generate deduplication key for a message
function getMessageKey(msg: ChatMessageData): string {
if (msg.type === "message") {
@@ -24,13 +70,11 @@ function getMessageKey(msg: ChatMessageData): string {
} else if (msg.type === "tool_call") {
return `toolcall:${msg.toolId}`;
} else if (msg.type === "tool_response") {
return `toolresponse:${(msg as any).toolId}`;
} else if (
msg.type === "operation_started" ||
msg.type === "operation_pending" ||
msg.type === "operation_in_progress"
) {
return `op:${(msg as any).toolId || (msg as any).operationId || (msg as any).toolCallId || ""}:${msg.toolName}`;
const toolId = hasToolId(msg) ? msg.toolId : "";
return `toolresponse:${toolId}`;
} else if (isOperationMessage(msg)) {
const toolId = getToolIdFromMessage(msg) || "";
return `op:${toolId}:${msg.toolName}`;
} else {
return `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`;
}
@@ -41,6 +85,11 @@ interface Args {
initialMessages: SessionDetailResponse["messages"];
initialPrompt?: string;
onOperationStarted?: () => void;
/** Active stream info from the server for reconnection */
activeStream?: {
taskId: string;
lastMessageId: string;
};
}
export function useChatContainer({
@@ -48,6 +97,7 @@ export function useChatContainer({
initialMessages,
initialPrompt,
onOperationStarted,
activeStream,
}: Args) {
const [messages, setMessages] = useState<ChatMessageData[]>([]);
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
@@ -65,30 +115,195 @@ export function useChatContainer({
} = useChatStream();
const activeStreams = useChatStore((s) => s.activeStreams);
const subscribeToStream = useChatStore((s) => s.subscribeToStream);
const setActiveTask = useChatStore((s) => s.setActiveTask);
const getActiveTask = useChatStore((s) => s.getActiveTask);
const reconnectToTask = useChatStore((s) => s.reconnectToTask);
const isStreaming = isStreamingInitiated || hasTextChunks;
// Track whether we've already connected to this activeStream to avoid duplicate connections
const connectedActiveStreamRef = useRef<string | null>(null);
// Callback to store active task info for SSE reconnection
function handleActiveTaskStarted(taskInfo: {
taskId: string;
operationId: string;
toolName: string;
toolCallId: string;
}) {
if (!sessionId) return;
setActiveTask(sessionId, {
taskId: taskInfo.taskId,
operationId: taskInfo.operationId,
toolName: taskInfo.toolName,
lastMessageId: INITIAL_STREAM_ID,
});
}
useEffect(
function handleSessionChange() {
if (sessionId === previousSessionIdRef.current) return;
const isSessionChange = sessionId !== previousSessionIdRef.current;
const prevSession = previousSessionIdRef.current;
if (prevSession) {
stopStreaming(prevSession);
console.info("[SSE-RECONNECT] handleSessionChange effect running:", {
sessionId,
previousSessionId: previousSessionIdRef.current,
isSessionChange,
hasActiveStream: !!activeStream,
activeStreamTaskId: activeStream?.taskId,
connectedActiveStream: connectedActiveStreamRef.current,
});
// Handle session change - reset state
if (isSessionChange) {
console.info("[SSE-RECONNECT] Session changed, resetting state");
const prevSession = previousSessionIdRef.current;
if (prevSession) {
stopStreaming(prevSession);
}
previousSessionIdRef.current = sessionId;
connectedActiveStreamRef.current = null; // Reset connected stream tracker
setMessages([]);
setStreamingChunks([]);
streamingChunksRef.current = [];
setHasTextChunks(false);
setIsStreamingInitiated(false);
hasResponseRef.current = false;
}
previousSessionIdRef.current = sessionId;
setMessages([]);
setStreamingChunks([]);
streamingChunksRef.current = [];
setHasTextChunks(false);
setIsStreamingInitiated(false);
hasResponseRef.current = false;
if (!sessionId) return;
if (!sessionId) {
console.info("[SSE-RECONNECT] No sessionId, skipping reconnection check");
return;
}
const activeStream = activeStreams.get(sessionId);
if (!activeStream || activeStream.status !== "streaming") return;
// Priority 1: Check if server told us there's an active stream (most authoritative)
// Also handles the case where activeStream arrives after initial session load
if (activeStream) {
// Skip if we've already connected to this exact stream
// Check and set immediately to prevent race conditions from effect re-runs
const streamKey = `${sessionId}:${activeStream.taskId}`;
if (connectedActiveStreamRef.current === streamKey) {
console.info(
"[SSE-RECONNECT] Already connected to this stream, skipping:",
{ streamKey },
);
return;
}
const dispatcher = createStreamEventDispatcher({
// Also skip if there's already an active stream for this session in the store
// (handles case where effect re-runs due to activeStreams state change)
const existingStream = activeStreams.get(sessionId);
if (existingStream && existingStream.status === "streaming") {
console.info(
"[SSE-RECONNECT] Active stream already exists in store, skipping:",
{ sessionId, status: existingStream.status },
);
connectedActiveStreamRef.current = streamKey;
return;
}
// Set immediately after check to prevent race conditions
connectedActiveStreamRef.current = streamKey;
console.info(
"[SSE-RECONNECT] Server reports active stream, initiating reconnection:",
{
sessionId,
taskId: activeStream.taskId,
lastMessageId: activeStream.lastMessageId,
streamKey,
},
);
const dispatcher = createDispatcher({
setHasTextChunks,
setStreamingChunks,
streamingChunksRef,
hasResponseRef,
setMessages,
setIsRegionBlockedModalOpen,
sessionId,
setIsStreamingInitiated,
onOperationStarted,
onActiveTaskStarted: handleActiveTaskStarted,
});
setIsStreamingInitiated(true);
// Store this as the active task for future reconnects
setActiveTask(sessionId, {
taskId: activeStream.taskId,
operationId: activeStream.taskId,
toolName: "chat",
lastMessageId: activeStream.lastMessageId,
});
// Reconnect to the task stream
console.info("[SSE-RECONNECT] Calling reconnectToTask...");
reconnectToTask(
sessionId,
activeStream.taskId,
activeStream.lastMessageId,
dispatcher,
);
return;
}
// Only check localStorage/in-memory on session change, not on every render
if (!isSessionChange) {
console.info(
"[SSE-RECONNECT] No active stream and not a session change, skipping fallbacks",
);
return;
}
// Priority 2: Check localStorage for active task (client-side state)
console.info("[SSE-RECONNECT] Checking localStorage for active task...");
const activeTask = getActiveTask(sessionId);
if (activeTask) {
console.info(
"[SSE-RECONNECT] Found active task in localStorage, attempting reconnect:",
{
sessionId,
taskId: activeTask.taskId,
lastMessageId: activeTask.lastMessageId,
},
);
const dispatcher = createDispatcher({
setHasTextChunks,
setStreamingChunks,
streamingChunksRef,
hasResponseRef,
setMessages,
setIsRegionBlockedModalOpen,
sessionId,
setIsStreamingInitiated,
onOperationStarted,
onActiveTaskStarted: handleActiveTaskStarted,
});
setIsStreamingInitiated(true);
// Reconnect to the task stream
console.info("[SSE-RECONNECT] Calling reconnectToTask from localStorage...");
reconnectToTask(
sessionId,
activeTask.taskId,
activeTask.lastMessageId,
dispatcher,
);
return;
} else {
console.info("[SSE-RECONNECT] No active task in localStorage");
}
// Priority 3: Check for an in-memory active stream (same-tab scenario)
console.info("[SSE-RECONNECT] Checking in-memory active streams...");
const inMemoryStream = activeStreams.get(sessionId);
if (!inMemoryStream || inMemoryStream.status !== "streaming") {
console.info("[SSE-RECONNECT] No in-memory active stream found:", {
hasStream: !!inMemoryStream,
status: inMemoryStream?.status,
});
return;
}
const dispatcher = createDispatcher({
setHasTextChunks,
setStreamingChunks,
streamingChunksRef,
@@ -98,6 +313,7 @@ export function useChatContainer({
sessionId,
setIsStreamingInitiated,
onOperationStarted,
onActiveTaskStarted: handleActiveTaskStarted,
});
setIsStreamingInitiated(true);
@@ -110,6 +326,10 @@ export function useChatContainer({
activeStreams,
subscribeToStream,
onOperationStarted,
getActiveTask,
reconnectToTask,
activeStream,
setActiveTask,
],
);
@@ -124,7 +344,7 @@ export function useChatContainer({
msg.type === "agent_carousel" ||
msg.type === "execution_started"
) {
const toolId = (msg as any).toolId;
const toolId = hasToolId(msg) ? msg.toolId : undefined;
if (toolId) {
ids.add(toolId);
}
@@ -141,12 +361,8 @@ export function useChatContainer({
setMessages((prev) => {
const filtered = prev.filter((msg) => {
if (
msg.type === "operation_started" ||
msg.type === "operation_pending" ||
msg.type === "operation_in_progress"
) {
const toolId = (msg as any).toolId || (msg as any).toolCallId;
if (isOperationMessage(msg)) {
const toolId = getToolIdFromMessage(msg);
if (toolId && completedToolIds.has(toolId)) {
return false; // Remove - operation completed
}
@@ -174,12 +390,8 @@ export function useChatContainer({
// Filter local messages: remove duplicates and completed operation messages
const newLocalMessages = messages.filter((msg) => {
// Remove operation messages for completed tools
if (
msg.type === "operation_started" ||
msg.type === "operation_pending" ||
msg.type === "operation_in_progress"
) {
const toolId = (msg as any).toolId || (msg as any).toolCallId;
if (isOperationMessage(msg)) {
const toolId = getToolIdFromMessage(msg);
if (toolId && completedToolIds.has(toolId)) {
return false;
}
@@ -215,7 +427,7 @@ export function useChatContainer({
setIsStreamingInitiated(true);
hasResponseRef.current = false;
const dispatcher = createStreamEventDispatcher({
const dispatcher = createDispatcher({
setHasTextChunks,
setStreamingChunks,
streamingChunksRef,
@@ -225,6 +437,7 @@ export function useChatContainer({
sessionId,
setIsStreamingInitiated,
onOperationStarted,
onActiveTaskStarted: handleActiveTaskStarted,
});
try {

View File

@@ -111,6 +111,7 @@ export type ChatMessageData =
toolName: string;
toolId: string;
operationId: string;
taskId?: string; // For SSE reconnection
message: string;
timestamp?: string | Date;
}

View File

@@ -1,3 +1,4 @@
import { INITIAL_MESSAGE_ID } from "./chat-constants";
import type {
ActiveStream,
StreamChunk,
@@ -10,8 +11,14 @@ import {
parseSSELine,
} from "./stream-utils";
function notifySubscribers(stream: ActiveStream, chunk: StreamChunk) {
stream.chunks.push(chunk);
function notifySubscribers(
stream: ActiveStream,
chunk: StreamChunk,
skipStore = false,
) {
if (!skipStore) {
stream.chunks.push(chunk);
}
for (const callback of stream.onChunkCallbacks) {
try {
callback(chunk);
@@ -21,42 +28,133 @@ function notifySubscribers(stream: ActiveStream, chunk: StreamChunk) {
}
}
export async function executeStream(
stream: ActiveStream,
message: string,
isUserMessage: boolean,
context?: { url: string; content: string },
retryCount: number = 0,
/**
* Options for stream execution.
*/
interface StreamExecutionOptions {
/** The active stream state object */
stream: ActiveStream;
/** Execution mode: 'new' for new stream, 'reconnect' for task reconnection */
mode: "new" | "reconnect";
/** Message content (required for 'new' mode) */
message?: string;
/** Whether this is a user message (for 'new' mode) */
isUserMessage?: boolean;
/** Optional context for the message (for 'new' mode) */
context?: { url: string; content: string };
/** Task ID (required for 'reconnect' mode) */
taskId?: string;
/** Last message ID for replay (for 'reconnect' mode) */
lastMessageId?: string;
/** Current retry count (internal use) */
retryCount?: number;
}
/**
* Unified stream execution function that handles both new streams and task reconnection.
*
* For new streams:
* - Posts a message to create a new chat stream
* - Reads SSE chunks and notifies subscribers
*
* For reconnection:
* - Connects to an existing task stream
* - Replays messages from lastMessageId position
* - Allows resumption of long-running operations
*/
async function executeStreamInternal(
options: StreamExecutionOptions,
): Promise<void> {
const {
stream,
mode,
message,
isUserMessage,
context,
taskId,
lastMessageId = INITIAL_MESSAGE_ID,
retryCount = 0,
} = options;
const { sessionId, abortController } = stream;
const isReconnect = mode === "reconnect";
const logPrefix = isReconnect ? "[SSE-RECONNECT]" : "[StreamExecutor]";
if (isReconnect) {
console.info(`${logPrefix} executeStream starting:`, {
taskId,
lastMessageId,
retryCount,
});
}
try {
const url = `/api/chat/sessions/${sessionId}/stream`;
const body = JSON.stringify({
message,
is_user_message: isUserMessage,
context: context || null,
});
// Build URL and request options based on mode
let url: string;
let fetchOptions: RequestInit;
const response = await fetch(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
Accept: "text/event-stream",
},
body,
signal: abortController.signal,
});
if (isReconnect) {
url = `/api/chat/tasks/${taskId}/stream?last_message_id=${encodeURIComponent(lastMessageId)}`;
fetchOptions = {
method: "GET",
headers: {
Accept: "text/event-stream",
},
signal: abortController.signal,
};
console.info(`${logPrefix} Fetching task stream:`, { url });
} else {
url = `/api/chat/sessions/${sessionId}/stream`;
fetchOptions = {
method: "POST",
headers: {
"Content-Type": "application/json",
Accept: "text/event-stream",
},
body: JSON.stringify({
message,
is_user_message: isUserMessage,
context: context || null,
}),
signal: abortController.signal,
};
}
const response = await fetch(url, fetchOptions);
if (isReconnect) {
console.info(`${logPrefix} Task stream response:`, {
status: response.status,
ok: response.ok,
});
}
if (!response.ok) {
const errorText = await response.text();
throw new Error(errorText || `HTTP ${response.status}`);
if (isReconnect) {
console.error(`${logPrefix} Task stream error response:`, {
status: response.status,
errorText,
});
}
// For reconnect: don't retry on 404/403 (permanent errors)
const isPermanentError =
isReconnect && (response.status === 404 || response.status === 403);
const error = new Error(errorText || `HTTP ${response.status}`);
(error as Error & { status?: number }).status = response.status;
(error as Error & { isPermanent?: boolean }).isPermanent =
isPermanentError;
throw error;
}
if (!response.body) {
throw new Error("Response body is null");
}
if (isReconnect) {
console.info(`${logPrefix} Task stream connected, reading chunks...`);
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
let buffer = "";
@@ -65,6 +163,11 @@ export async function executeStream(
const { done, value } = await reader.read();
if (done) {
if (isReconnect) {
console.info(
`${logPrefix} Task stream reader done (connection closed)`,
);
}
notifySubscribers(stream, { type: "stream_end" });
stream.status = "completed";
return;
@@ -78,6 +181,9 @@ export async function executeStream(
const data = parseSSELine(line);
if (data !== null) {
if (data === "[DONE]") {
if (isReconnect) {
console.info(`${logPrefix} Task stream received [DONE] signal`);
}
notifySubscribers(stream, { type: "stream_end" });
stream.status = "completed";
return;
@@ -90,14 +196,30 @@ export async function executeStream(
const chunk = normalizeStreamChunk(rawChunk);
if (!chunk) continue;
// Log first few chunks for debugging (reconnect mode only)
if (isReconnect && stream.chunks.length < 3) {
console.info(`${logPrefix} Task stream chunk received:`, {
type: chunk.type,
chunkIndex: stream.chunks.length,
});
}
notifySubscribers(stream, chunk);
if (chunk.type === "stream_end") {
if (isReconnect) {
console.info(
`${logPrefix} Task stream completed via stream_end chunk`,
);
}
stream.status = "completed";
return;
}
if (chunk.type === "error") {
if (isReconnect) {
console.error(`${logPrefix} Task stream error chunk:`, chunk);
}
stream.status = "error";
stream.error = new Error(
chunk.message || chunk.content || "Stream error",
@@ -105,7 +227,7 @@ export async function executeStream(
return;
}
} catch (err) {
console.warn("[StreamExecutor] Failed to parse SSE chunk:", err);
console.warn(`${logPrefix} Failed to parse SSE chunk:`, err);
}
}
}
@@ -117,18 +239,27 @@ export async function executeStream(
return;
}
if (retryCount < MAX_RETRIES) {
// Check if this is a permanent error (404/403) that shouldn't be retried
const isPermanentError =
err instanceof Error &&
(err as Error & { isPermanent?: boolean }).isPermanent;
if (!isPermanentError && retryCount < MAX_RETRIES) {
const retryDelay = INITIAL_RETRY_DELAY * Math.pow(2, retryCount);
console.log(
`[StreamExecutor] Retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`,
`${logPrefix} Retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`,
);
await new Promise((resolve) => setTimeout(resolve, retryDelay));
return executeStream(
stream,
message,
isUserMessage,
context,
retryCount + 1,
return executeStreamInternal({
...options,
retryCount: retryCount + 1,
});
}
// Log permanent errors differently for debugging
if (isPermanentError) {
console.log(
`${logPrefix} Stream failed permanently (task not found or access denied): ${(err as Error).message}`,
);
}
@@ -140,3 +271,52 @@ export async function executeStream(
});
}
}
/**
* Execute a new chat stream.
*
* Posts a message to create a new stream and reads SSE responses.
*/
export async function executeStream(
stream: ActiveStream,
message: string,
isUserMessage: boolean,
context?: { url: string; content: string },
retryCount: number = 0,
): Promise<void> {
return executeStreamInternal({
stream,
mode: "new",
message,
isUserMessage,
context,
retryCount,
});
}
/**
* Reconnect to an existing task stream.
*
* This is used when a client wants to resume receiving updates from a
* long-running background task. Messages are replayed from the last_message_id
* position, allowing clients to catch up on missed events.
*
* @param stream - The active stream state
* @param taskId - The task ID to reconnect to
* @param lastMessageId - The last message ID received (for replay)
* @param retryCount - Current retry count
*/
export async function executeTaskReconnect(
stream: ActiveStream,
taskId: string,
lastMessageId: string = INITIAL_MESSAGE_ID,
retryCount: number = 0,
): Promise<void> {
return executeStreamInternal({
stream,
mode: "reconnect",
taskId,
lastMessageId,
retryCount,
});
}

View File

@@ -28,7 +28,8 @@ export function normalizeStreamChunk(
switch (chunk.type) {
case "text-delta":
return { type: "text_chunk", content: chunk.delta };
// Backend sends "content", Vercel AI SDK sends "delta"
return { type: "text_chunk", content: chunk.delta || chunk.content };
case "text-end":
return { type: "text_ended" };
case "tool-input-available":
@@ -63,6 +64,10 @@ export function normalizeStreamChunk(
case "finish":
return { type: "stream_end" };
case "start":
// Start event with optional taskId for reconnection
return chunk.taskId
? { type: "stream_start", taskId: chunk.taskId }
: null;
case "text-start":
return null;
case "tool-input-start":