mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-02 10:55:14 -05:00
Compare commits
16 Commits
ntindle/fi
...
swiftyos/s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
070d56166c | ||
|
|
ef3fab57fd | ||
|
|
e812ee9265 | ||
|
|
eb3872d78b | ||
|
|
2cdfe90c56 | ||
|
|
1081590384 | ||
|
|
d1da7fe5da | ||
|
|
11e27cfdcf | ||
|
|
0be5fedc86 | ||
|
|
f2e81648b5 | ||
|
|
bb608ea60d | ||
|
|
46af3b94f2 | ||
|
|
083cceca0f | ||
|
|
06758adefd | ||
|
|
c01c29a059 | ||
|
|
d738059da8 |
@@ -0,0 +1,303 @@
|
||||
"""Redis Streams consumer for operation completion messages.
|
||||
|
||||
This module provides a consumer that listens for completion notifications
|
||||
from external services (like Agent Generator) and triggers the appropriate
|
||||
stream registry and chat service updates.
|
||||
|
||||
The consumer uses Redis Streams with consumer groups for reliable message
|
||||
processing across multiple platform pods.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import orjson
|
||||
from prisma import Prisma
|
||||
from pydantic import BaseModel
|
||||
from redis.exceptions import ResponseError
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
from . import stream_registry
|
||||
from .completion_handler import process_operation_failure, process_operation_success
|
||||
from .config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
class OperationCompleteMessage(BaseModel):
    """Message format for operation completion notifications."""

    # ID of the external operation that finished.
    operation_id: str
    # ID of the chat task the operation belongs to.
    task_id: str
    # True when the operation succeeded; False when it failed.
    success: bool
    # Result payload for successful operations (dict or pre-serialized string).
    result: dict | str | None = None
    # Error message for failed operations.
    error: str | None = None
|
||||
|
||||
|
||||
class ChatCompletionConsumer:
    """Consumer for chat operation completion messages from Redis Streams.

    This consumer initializes its own Prisma client in start() to ensure
    database operations work correctly within this async context.

    Uses Redis consumer groups to allow multiple platform pods to consume
    messages reliably with automatic redelivery on failure.
    """

    def __init__(self):
        # Background polling task; None while the consumer is stopped.
        self._consumer_task: asyncio.Task | None = None
        self._running = False
        # Dedicated Prisma connection, created lazily (see _ensure_prisma).
        self._prisma: Prisma | None = None
        # Per-pod unique consumer name within the Redis consumer group.
        self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"

    async def start(self) -> None:
        """Start the completion consumer."""
        if self._running:
            logger.warning("Completion consumer already running")
            return

        # Ensure the consumer group exists; BUSYGROUP means another pod
        # already created it, which is fine.
        try:
            redis = await get_redis_async()
            await redis.xgroup_create(
                config.stream_completion_name,
                config.stream_consumer_group,
                id="0",
                mkstream=True,
            )
            logger.info(
                f"Created consumer group '{config.stream_consumer_group}' "
                f"on stream '{config.stream_completion_name}'"
            )
        except ResponseError as e:
            if "BUSYGROUP" not in str(e):
                raise
            logger.debug(
                f"Consumer group '{config.stream_consumer_group}' already exists"
            )

        self._running = True
        self._consumer_task = asyncio.create_task(self._consume_messages())
        logger.info(
            f"Chat completion consumer started (consumer: {self._consumer_name})"
        )

    async def _ensure_prisma(self) -> Prisma:
        """Lazily initialize Prisma client on first use."""
        if self._prisma is None:
            # Fallback URL is only a local-dev convenience; production sets
            # DATABASE_URL explicitly.
            db_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
            self._prisma = Prisma(datasource={"url": db_url})
            await self._prisma.connect()
            logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
        return self._prisma

    async def stop(self) -> None:
        """Stop the completion consumer."""
        self._running = False

        task = self._consumer_task
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
            self._consumer_task = None

        if self._prisma:
            await self._prisma.disconnect()
            self._prisma = None
            logger.info("[COMPLETION] Consumer Prisma client disconnected")

        logger.info("Chat completion consumer stopped")

    async def _consume_messages(self) -> None:
        """Main message consumption loop with retry logic."""
        max_retries = 10
        retry_delay = 5  # seconds between reconnect attempts
        retry_count = 0
        block_timeout = 5000  # ms to block in XREADGROUP before returning empty

        while self._running and retry_count < max_retries:
            try:
                redis = await get_redis_async()

                # A successful (re)connection resets the retry budget.
                retry_count = 0

                while self._running:
                    # ">" asks for messages never delivered to this group.
                    batch = await redis.xreadgroup(
                        groupname=config.stream_consumer_group,
                        consumername=self._consumer_name,
                        streams={config.stream_completion_name: ">"},
                        block=block_timeout,
                        count=10,
                    )

                    if not batch:
                        continue

                    for _stream, entries in batch:
                        for message_id, fields in entries:
                            if not self._running:
                                return

                            try:
                                payload = fields.get("data")
                                if payload:
                                    await self._handle_message(
                                        payload.encode()
                                        if isinstance(payload, str)
                                        else payload
                                    )

                                # Ack only after successful handling.
                                await redis.xack(
                                    config.stream_completion_name,
                                    config.stream_consumer_group,
                                    message_id,
                                )
                            except Exception as e:
                                # Unacked entries stay pending and can be
                                # redelivered or claimed by another consumer.
                                logger.error(
                                    f"Error processing completion message {message_id}: {e}",
                                    exc_info=True,
                                )

            except asyncio.CancelledError:
                logger.info("Consumer cancelled")
                return
            except Exception as e:
                retry_count += 1
                logger.error(
                    f"Consumer error (retry {retry_count}/{max_retries}): {e}",
                    exc_info=True,
                )
                if self._running and retry_count < max_retries:
                    await asyncio.sleep(retry_delay)
                else:
                    logger.error("Max retries reached, stopping consumer")
                    return

    async def _handle_message(self, body: bytes) -> None:
        """Handle a completion message using our own Prisma client."""
        try:
            parsed = orjson.loads(body)
            message = OperationCompleteMessage(**parsed)
        except Exception as e:
            logger.error(f"Failed to parse completion message: {e}")
            return

        logger.info(
            f"[COMPLETION] Received completion for operation {message.operation_id} "
            f"(task_id={message.task_id}, success={message.success})"
        )

        # Prefer lookup by operation ID; fall back to the task ID.
        task = await stream_registry.find_task_by_operation_id(message.operation_id)
        if task is None:
            task = await stream_registry.get_task(message.task_id)

        if task is None:
            logger.warning(
                f"[COMPLETION] Task not found for operation {message.operation_id} "
                f"(task_id={message.task_id})"
            )
            return

        logger.info(
            f"[COMPLETION] Found task: task_id={task.task_id}, "
            f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
        )

        # Guard against empty task fields
        if not task.task_id or not task.session_id or not task.tool_call_id:
            logger.error(
                f"[COMPLETION] Task has empty critical fields! "
                f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
                f"tool_call_id={task.tool_call_id!r}"
            )
            return

        if message.success:
            await self._handle_success(task, message)
        else:
            await self._handle_failure(task, message)

    async def _handle_success(
        self,
        task: stream_registry.ActiveTask,
        message: OperationCompleteMessage,
    ) -> None:
        """Handle successful operation completion."""
        prisma = await self._ensure_prisma()
        await process_operation_success(task, message.result, prisma)

    async def _handle_failure(
        self,
        task: stream_registry.ActiveTask,
        message: OperationCompleteMessage,
    ) -> None:
        """Handle failed operation completion."""
        prisma = await self._ensure_prisma()
        await process_operation_failure(task, message.error, prisma)
|
||||
|
||||
|
||||
# Module-level consumer instance
|
||||
_consumer: ChatCompletionConsumer | None = None
|
||||
|
||||
|
||||
async def start_completion_consumer() -> None:
    """Start the global completion consumer (idempotent)."""
    global _consumer
    if _consumer is not None:
        return
    _consumer = ChatCompletionConsumer()
    await _consumer.start()
|
||||
|
||||
|
||||
async def stop_completion_consumer() -> None:
    """Stop the global completion consumer, if one is running."""
    global _consumer
    if _consumer is None:
        return
    await _consumer.stop()
    _consumer = None
|
||||
|
||||
|
||||
async def publish_operation_complete(
    operation_id: str,
    task_id: str,
    success: bool,
    result: dict | str | None = None,
    error: str | None = None,
) -> None:
    """Publish an operation completion message to Redis Streams.

    Args:
        operation_id: The operation ID that completed.
        task_id: The task ID associated with the operation.
        success: Whether the operation succeeded.
        result: The result data (for success).
        error: The error message (for failure).
    """
    payload = OperationCompleteMessage(
        operation_id=operation_id,
        task_id=task_id,
        success=success,
        result=result,
        error=error,
    ).model_dump_json()

    redis = await get_redis_async()
    # maxlen caps the stream so old completions are trimmed automatically.
    await redis.xadd(
        config.stream_completion_name,
        {"data": payload},
        maxlen=config.stream_max_length,
    )
    logger.info(f"Published completion for operation {operation_id}")
|
||||
@@ -0,0 +1,255 @@
|
||||
"""Shared completion handling for operation success and failure.
|
||||
|
||||
This module provides common logic for handling operation completion from both:
|
||||
- The Redis Streams consumer (completion_consumer.py)
|
||||
- The HTTP webhook endpoint (routes.py)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from prisma import Prisma
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .response_model import StreamError, StreamFinish, StreamToolOutputAvailable
|
||||
from .tools.models import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Tools that produce agent_json that needs to be saved to library
|
||||
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
|
||||
|
||||
|
||||
def serialize_result(result: dict | str | None) -> str:
|
||||
"""Serialize result to JSON string with sensible defaults.
|
||||
|
||||
Args:
|
||||
result: The result to serialize (dict, string, or None)
|
||||
|
||||
Returns:
|
||||
JSON string representation of the result
|
||||
"""
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
if result:
|
||||
return orjson.dumps(result).decode("utf-8")
|
||||
return '{"status": "completed"}'
|
||||
|
||||
|
||||
async def _save_agent_from_result(
    result: dict[str, Any],
    user_id: str | None,
    tool_name: str,
) -> dict[str, Any]:
    """Save agent to library if result contains agent_json.

    Args:
        result: The result dict that may contain agent_json
        user_id: The user ID to save the agent for
        tool_name: The tool name (create_agent or edit_agent)

    Returns:
        Updated result dict with saved agent details, or original result if no agent_json
    """
    # Without a user there is no library to save into.
    if not user_id:
        logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
        return result

    agent_json = result.get("agent_json")
    if not agent_json:
        logger.warning(
            f"[COMPLETION] {tool_name} completed but no agent_json in result"
        )
        return result

    try:
        # Imported lazily to avoid a heavier import at module load time.
        from .tools.agent_generator import save_agent_to_library

        created_graph, library_agent = await save_agent_to_library(
            agent_json,
            user_id,
            is_update=(tool_name == "edit_agent"),
        )

        logger.info(
            f"[COMPLETION] Saved agent '{created_graph.name}' to library "
            f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
        )

        # Return a response similar to AgentSavedResponse
        return {
            "type": "agent_saved",
            "message": f"Agent '{created_graph.name}' has been saved to your library!",
            "agent_id": created_graph.id,
            "agent_name": created_graph.name,
            "library_agent_id": library_agent.id,
            "library_agent_link": f"/library/agents/{library_agent.id}",
            "agent_page_link": f"/build?flowID={created_graph.id}",
        }
    except Exception as e:
        logger.error(
            f"[COMPLETION] Failed to save agent to library: {e}",
            exc_info=True,
        )
        # Return error but don't fail the whole operation
        return {
            "type": "error",
            "message": f"Agent was generated but failed to save: {str(e)}",
            "error": str(e),
            "agent_json": agent_json,  # Include the JSON so user can retry
        }
|
||||
|
||||
|
||||
async def process_operation_success(
    task: stream_registry.ActiveTask,
    result: dict | str | None,
    prisma_client: Prisma | None = None,
) -> None:
    """Handle successful operation completion.

    Publishes the result to the stream registry, updates the database,
    generates LLM continuation, and marks the task as completed.

    Args:
        task: The active task that completed
        result: The result data from the operation
        prisma_client: Optional Prisma client for database operations.
            If None, uses chat_service._update_pending_operation instead.
    """
    # For agent generation tools, save the agent to library
    if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
        result = await _save_agent_from_result(result, task.user_id, task.tool_name)

    # Serialize exactly once via the shared helper so the SSE chunk and the
    # database row always carry the same content. (Previously the stream
    # output was serialized by a near-duplicate inline code path with a
    # slightly different empty-result default.)
    result_str = serialize_result(result)

    # Publish result to stream registry
    await stream_registry.publish_chunk(
        task.task_id,
        StreamToolOutputAvailable(
            toolCallId=task.tool_call_id,
            toolName=task.tool_name,
            output=result_str,
            success=True,
        ),
    )

    # Update pending operation in database; failures are logged but do not
    # abort the completion flow.
    try:
        if prisma_client:
            # Use provided Prisma client (for consumer with its own connection)
            await prisma_client.chatmessage.update_many(
                where={
                    "sessionId": task.session_id,
                    "toolCallId": task.tool_call_id,
                },
                data={"content": result_str},
            )
            logger.info(
                f"[COMPLETION] Updated tool message for session {task.session_id}"
            )
        else:
            # Use service function (for webhook endpoint)
            await chat_service._update_pending_operation(
                session_id=task.session_id,
                tool_call_id=task.tool_call_id,
                result=result_str,
            )
    except Exception as e:
        logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)

    # Generate LLM continuation with streaming
    try:
        await chat_service._generate_llm_continuation_with_streaming(
            session_id=task.session_id,
            user_id=task.user_id,
            task_id=task.task_id,
        )
    except Exception as e:
        logger.error(
            f"[COMPLETION] Failed to generate LLM continuation: {e}",
            exc_info=True,
        )

    # Mark task as completed and release Redis lock
    await stream_registry.mark_task_completed(task.task_id, status="completed")
    try:
        await chat_service._mark_operation_completed(task.tool_call_id)
    except Exception as e:
        logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")

    logger.info(
        f"[COMPLETION] Successfully processed completion for task {task.task_id}"
    )
|
||||
|
||||
|
||||
async def process_operation_failure(
    task: stream_registry.ActiveTask,
    error: str | None,
    prisma_client: Prisma | None = None,
) -> None:
    """Handle failed operation completion.

    Publishes the error to the stream registry, updates the database with
    the error response, and marks the task as failed.

    Args:
        task: The active task that failed
        error: The error message from the operation
        prisma_client: Optional Prisma client for database operations.
            If None, uses chat_service._update_pending_operation instead.
    """
    error_msg = error or "Operation failed"

    # Surface the error to any connected SSE subscribers, then close the stream.
    await stream_registry.publish_chunk(
        task.task_id,
        StreamError(errorText=error_msg),
    )
    await stream_registry.publish_chunk(task.task_id, StreamFinish())

    # Persist a structured error payload on the pending tool message.
    error_payload = ErrorResponse(
        message=error_msg,
        error=error,
    ).model_dump_json()
    try:
        if prisma_client:
            # Use provided Prisma client (for consumer with its own connection)
            await prisma_client.chatmessage.update_many(
                where={
                    "sessionId": task.session_id,
                    "toolCallId": task.tool_call_id,
                },
                data={"content": error_payload},
            )
            logger.info(
                f"[COMPLETION] Updated tool message with error for session {task.session_id}"
            )
        else:
            # Use service function (for webhook endpoint)
            await chat_service._update_pending_operation(
                session_id=task.session_id,
                tool_call_id=task.tool_call_id,
                result=error_payload,
            )
    except Exception as e:
        logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)

    # Mark task as failed and release Redis lock
    await stream_registry.mark_task_completed(task.task_id, status="failed")
    try:
        await chat_service._mark_operation_completed(task.tool_call_id)
    except Exception as e:
        logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")

    logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")
|
||||
@@ -44,6 +44,48 @@ class ChatConfig(BaseSettings):
|
||||
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
|
||||
)
|
||||
|
||||
# Stream registry configuration for SSE reconnection
|
||||
stream_ttl: int = Field(
|
||||
default=3600,
|
||||
description="TTL in seconds for stream data in Redis (1 hour)",
|
||||
)
|
||||
stream_max_length: int = Field(
|
||||
default=10000,
|
||||
description="Maximum number of messages to store per stream",
|
||||
)
|
||||
|
||||
# Redis Streams configuration for completion consumer
|
||||
stream_completion_name: str = Field(
|
||||
default="chat:completions",
|
||||
description="Redis Stream name for operation completions",
|
||||
)
|
||||
stream_consumer_group: str = Field(
|
||||
default="chat_consumers",
|
||||
description="Consumer group name for completion stream",
|
||||
)
|
||||
|
||||
# Redis key prefixes for stream registry
|
||||
task_meta_prefix: str = Field(
|
||||
default="chat:task:meta:",
|
||||
description="Prefix for task metadata hash keys",
|
||||
)
|
||||
task_stream_prefix: str = Field(
|
||||
default="chat:stream:",
|
||||
description="Prefix for task message stream keys",
|
||||
)
|
||||
task_op_prefix: str = Field(
|
||||
default="chat:task:op:",
|
||||
description="Prefix for operation ID to task ID mapping keys",
|
||||
)
|
||||
task_pubsub_prefix: str = Field(
|
||||
default="chat:task:pubsub:",
|
||||
description="Prefix for task pub/sub channel names",
|
||||
)
|
||||
internal_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
|
||||
)
|
||||
|
||||
# Langfuse Prompt Management Configuration
|
||||
# Note: Langfuse credentials are in Settings().secrets (settings.py)
|
||||
langfuse_prompt_name: str = Field(
|
||||
@@ -82,6 +124,14 @@ class ChatConfig(BaseSettings):
|
||||
v = "https://openrouter.ai/api/v1"
|
||||
return v
|
||||
|
||||
@field_validator("internal_api_key", mode="before")
|
||||
@classmethod
|
||||
def get_internal_api_key(cls, v):
|
||||
"""Get internal API key from environment if not provided."""
|
||||
if v is None:
|
||||
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
||||
return v
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
|
||||
@@ -52,6 +52,10 @@ class StreamStart(StreamBaseResponse):
|
||||
|
||||
type: ResponseType = ResponseType.START
|
||||
messageId: str = Field(..., description="Unique message ID")
|
||||
taskId: str | None = Field(
|
||||
default=None,
|
||||
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
|
||||
)
|
||||
|
||||
|
||||
class StreamFinish(StreamBaseResponse):
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
"""Chat API routes for chat session management and streaming via SSE."""
|
||||
|
||||
import logging
|
||||
import uuid as uuid_module
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, Query, Security
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .completion_handler import process_operation_failure, process_operation_success
|
||||
from .config import ChatConfig
|
||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
@@ -55,6 +59,13 @@ class CreateSessionResponse(BaseModel):
|
||||
user_id: str | None
|
||||
|
||||
|
||||
class ActiveStreamInfo(BaseModel):
    """Information about an active stream for reconnection."""

    # Task the client should reconnect to via GET /tasks/{task_id}/stream.
    task_id: str
    # Redis Stream message ID to resume replay from.
    last_message_id: str
|
||||
|
||||
|
||||
class SessionDetailResponse(BaseModel):
|
||||
"""Response model providing complete details for a chat session, including messages."""
|
||||
|
||||
@@ -63,6 +74,7 @@ class SessionDetailResponse(BaseModel):
|
||||
updated_at: str
|
||||
user_id: str | None
|
||||
messages: list[dict]
|
||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||
|
||||
|
||||
class SessionSummaryResponse(BaseModel):
|
||||
@@ -81,6 +93,14 @@ class ListSessionsResponse(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
class OperationCompleteRequest(BaseModel):
    """Request model for external completion webhook."""

    # Whether the external operation succeeded.
    success: bool
    # Result payload on success (dict or pre-serialized string).
    result: dict | str | None = None
    # Error message on failure.
    error: str | None = None
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@@ -166,13 +186,14 @@ async def get_session(
|
||||
Retrieve the details of a specific chat session.
|
||||
|
||||
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
||||
If there's an active stream for this session, returns the task_id for reconnection.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the desired chat session.
|
||||
user_id: The optional authenticated user ID, or None for anonymous access.
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session, or None if not found.
|
||||
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
|
||||
|
||||
"""
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
@@ -180,10 +201,45 @@ async def get_session(
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
|
||||
# Check if there's an active stream for this session
|
||||
active_stream_info = None
|
||||
logger.info(f"[SSE-RECONNECT] Checking for active stream in session {session_id}")
|
||||
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
||||
session_id, user_id
|
||||
)
|
||||
if active_task:
|
||||
# Filter out the in-progress assistant message from the session response.
|
||||
# The client will receive the complete assistant response through the SSE
|
||||
# stream replay instead, preventing duplicate content.
|
||||
if messages and messages[-1].get("role") == "assistant":
|
||||
original_count = len(messages)
|
||||
messages = messages[:-1]
|
||||
logger.info(
|
||||
f"[SSE-RECONNECT] Filtered out in-progress assistant message "
|
||||
f"(was {original_count} messages, now {len(messages)})"
|
||||
)
|
||||
|
||||
# Use "0-0" as last_message_id to replay the stream from the beginning.
|
||||
# Since we filtered out the cached assistant message, the client needs
|
||||
# the full stream to reconstruct the response.
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
task_id=active_task.task_id,
|
||||
last_message_id="0-0",
|
||||
)
|
||||
logger.info(
|
||||
f"[SSE-RECONNECT] Session {session_id} HAS active stream: "
|
||||
f"task_id={active_task.task_id}, status={active_task.status}, "
|
||||
f"last_message_id=0-0 (replay from start)"
|
||||
)
|
||||
else:
|
||||
logger.info(f"[SSE-RECONNECT] Session {session_id} has NO active stream")
|
||||
|
||||
logger.info(
|
||||
f"Returning session {session_id}: "
|
||||
f"message_count={len(messages)}, "
|
||||
f"roles={[m.get('role') for m in messages]}"
|
||||
f"roles={[m.get('role') for m in messages]}, "
|
||||
f"has_active_stream={active_stream_info is not None}"
|
||||
)
|
||||
|
||||
return SessionDetailResponse(
|
||||
@@ -192,6 +248,7 @@ async def get_session(
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
messages=messages,
|
||||
active_stream=active_stream_info,
|
||||
)
|
||||
|
||||
|
||||
@@ -211,49 +268,136 @@ async def stream_chat_post(
|
||||
- Tool call UI elements (if invoked)
|
||||
- Tool execution results
|
||||
|
||||
The AI generation runs in a background task that continues even if the client disconnects.
|
||||
All chunks are written to Redis for reconnection support. If the client disconnects,
|
||||
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
|
||||
containing the task_id for reconnection.
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
# Create a task in the stream registry for reconnection support
|
||||
task_id = str(uuid_module.uuid4())
|
||||
operation_id = str(uuid_module.uuid4())
|
||||
await stream_registry.create_task(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id="chat_stream", # Not a tool call, but needed for the model
|
||||
tool_name="chat",
|
||||
operation_id=operation_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[SSE-RECONNECT] Created stream task for reconnection support: "
|
||||
f"task_id={task_id}, session_id={session_id}"
|
||||
)
|
||||
|
||||
# Background task that runs the AI generation independently of SSE connection
|
||||
async def run_ai_generation():
|
||||
chunk_count = 0
|
||||
first_chunk_type: str | None = None
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
):
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Chat stream chunk",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_type": str(chunk.type),
|
||||
},
|
||||
)
|
||||
if not first_chunk_type:
|
||||
first_chunk_type = str(chunk.type)
|
||||
chunk_count += 1
|
||||
yield chunk.to_sse()
|
||||
logger.info(
|
||||
"Chat stream completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_count": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
)
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
try:
|
||||
# Emit a start event with task_id for reconnection
|
||||
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
|
||||
await stream_registry.publish_chunk(task_id, start_chunk)
|
||||
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
):
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Chat stream chunk",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_type": str(chunk.type),
|
||||
},
|
||||
)
|
||||
if not first_chunk_type:
|
||||
first_chunk_type = str(chunk.type)
|
||||
chunk_count += 1
|
||||
# Write to Redis (subscribers will receive via pub/sub or polling)
|
||||
await stream_registry.publish_chunk(task_id, chunk)
|
||||
|
||||
# Mark task as completed
|
||||
await stream_registry.mark_task_completed(task_id, "completed")
|
||||
logger.info(
|
||||
"[SSE-RECONNECT] Background AI generation completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"task_id": task_id,
|
||||
"chunk_count": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[SSE-RECONNECT] Error in background AI generation for session "
|
||||
f"{session_id}: {e}"
|
||||
)
|
||||
await stream_registry.mark_task_completed(task_id, "failed")
|
||||
|
||||
# Start the AI generation in a background task
|
||||
bg_task = asyncio.create_task(run_ai_generation())
|
||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||
logger.info(f"[SSE-RECONNECT] Started background AI generation task for {task_id}")
|
||||
|
||||
# SSE endpoint that subscribes to the task's stream
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
try:
|
||||
# Subscribe to the task stream (this replays existing messages + live updates)
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
user_id=user_id,
|
||||
last_message_id="0-0", # Get all messages from the beginning
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
logger.error(f"Failed to subscribe to task {task_id}")
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
# Read from the subscriber queue and yield to SSE
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||
yield chunk.to_sse()
|
||||
|
||||
# Check for finish signal
|
||||
if isinstance(chunk, StreamFinish):
|
||||
logger.info(
|
||||
f"[SSE-RECONNECT] SSE subscriber received finish for task {task_id}"
|
||||
)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
# Send heartbeat to keep connection alive
|
||||
yield StreamHeartbeat().to_sse()
|
||||
|
||||
except GeneratorExit:
|
||||
# Client disconnected - that's fine, background task continues
|
||||
logger.info(
|
||||
f"[SSE-RECONNECT] SSE client disconnected for task {task_id}, "
|
||||
f"background generation continues"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in SSE stream for task {task_id}: {e}")
|
||||
finally:
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
@@ -366,6 +510,207 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Task Streaming (SSE Reconnection) ==========
|
||||
|
||||
|
||||
@router.get(
    "/tasks/{task_id}/stream",
)
async def stream_task(
    task_id: str,
    user_id: str | None = Depends(auth.get_user_id),
    last_message_id: str = Query(
        default="0-0",
        description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
    ),
):
    """
    Reconnect to a long-running task's SSE stream.

    When a long-running operation (like agent generation) starts, the client
    receives a task_id. If the connection drops, the client can reconnect
    using this endpoint to resume receiving updates.

    Args:
        task_id: The task ID from the operation_started response.
        user_id: Authenticated user ID for ownership validation.
        last_message_id: Last Redis Stream message ID received ("0-0" for full replay).

    Returns:
        StreamingResponse: SSE-formatted response chunks starting after last_message_id.

    Raises:
        NotFoundError: If task_id is not found or user doesn't have access.
    """
    logger.info(
        f"[SSE-RECONNECT] Client reconnecting to task stream: "
        f"task_id={task_id}, last_message_id={last_message_id}"
    )

    # Get subscriber queue from stream registry. The registry replays messages
    # recorded after last_message_id, then keeps delivering live updates.
    subscriber_queue = await stream_registry.subscribe_to_task(
        task_id=task_id,
        user_id=user_id,
        last_message_id=last_message_id,
    )

    if subscriber_queue is None:
        # None covers both "unknown task" and "task owned by another user";
        # the error message deliberately does not distinguish the two.
        logger.warning(
            f"[SSE-RECONNECT] Task not found or access denied: task_id={task_id}"
        )
        raise NotFoundError(f"Task {task_id} not found or access denied.")

    logger.info(
        f"[SSE-RECONNECT] Successfully subscribed to task stream: task_id={task_id}"
    )

    async def event_generator() -> AsyncGenerator[str, None]:
        # Function-scope import; redundant if asyncio is already imported at
        # module level, but harmless.
        import asyncio

        chunk_count = 0
        heartbeat_interval = 15.0  # Send heartbeat every 15 seconds
        try:
            while True:
                try:
                    # Wait for next chunk with timeout for heartbeats
                    chunk = await asyncio.wait_for(
                        subscriber_queue.get(), timeout=heartbeat_interval
                    )
                    chunk_count += 1
                    yield chunk.to_sse()

                    # Check for finish signal
                    if isinstance(chunk, StreamFinish):
                        logger.info(
                            f"Task stream completed for task {task_id}, "
                            f"chunk_count={chunk_count}"
                        )
                        break
                except asyncio.TimeoutError:
                    # Send heartbeat to keep connection alive
                    yield StreamHeartbeat().to_sse()
        except Exception as e:
            logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
        finally:
            # Unsubscribe when client disconnects or stream ends
            await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)

            # AI SDK protocol termination
            yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            "x-vercel-ai-ui-message-stream": "v1",
        },
    )
|
||||
|
||||
|
||||
@router.get(
    "/tasks/{task_id}",
)
async def get_task_status(
    task_id: str,
    user_id: str | None = Depends(auth.get_user_id),
) -> dict:
    """
    Get the status of a long-running task.

    Args:
        task_id: The task ID to check.
        user_id: Authenticated user ID for ownership validation.

    Returns:
        dict: Task status including task_id, status, tool_name, and operation_id.

    Raises:
        NotFoundError: If task_id is not found or user doesn't have access.
    """
    task = await stream_registry.get_task(task_id)

    # A task owned by a different user is reported as "not found" so the
    # endpoint does not leak task existence across users.
    owned_by_other = (
        task is not None
        and bool(user_id)
        and bool(task.user_id)
        and task.user_id != user_id
    )
    if task is None or owned_by_other:
        raise NotFoundError(f"Task {task_id} not found.")

    return {
        "task_id": task.task_id,
        "session_id": task.session_id,
        "status": task.status,
        "tool_name": task.tool_name,
        "operation_id": task.operation_id,
        "created_at": task.created_at.isoformat(),
    }
|
||||
|
||||
|
||||
# ========== External Completion Webhook ==========
|
||||
|
||||
|
||||
@router.post(
    "/operations/{operation_id}/complete",
    status_code=200,
)
async def complete_operation(
    operation_id: str,
    request: OperationCompleteRequest,
    x_api_key: str | None = Header(default=None),
) -> dict:
    """
    External completion webhook for long-running operations.

    Called by Agent Generator (or other services) when an operation completes.
    This triggers the stream registry to publish completion and continue LLM generation.

    Args:
        operation_id: The operation ID to complete.
        request: Completion payload with success status and result/error.
        x_api_key: Internal API key for authentication.

    Returns:
        dict: Status of the completion.

    Raises:
        HTTPException: If API key is invalid or operation not found.
    """
    # If no internal key is configured, the webhook is effectively disabled:
    # refuse with 503 rather than accept unauthenticated completions.
    configured_key = config.internal_api_key
    if not configured_key:
        logger.error(
            "Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
        )
        raise HTTPException(
            status_code=503,
            detail="Webhook not available: internal API key not configured",
        )
    if x_api_key != configured_key:
        raise HTTPException(status_code=401, detail="Invalid API key")

    # Resolve the operation to its streaming task.
    task = await stream_registry.find_task_by_operation_id(operation_id)
    if task is None:
        raise HTTPException(
            status_code=404,
            detail=f"Operation {operation_id} not found",
        )

    logger.info(
        f"Received completion webhook for operation {operation_id} "
        f"(task_id={task.task_id}, success={request.success})"
    )

    # Dispatch to the success or failure handler.
    if request.success:
        await process_operation_success(task, request.result)
    else:
        await process_operation_failure(task, request.error)

    return {"status": "ok", "task_id": task.task_id}
|
||||
|
||||
|
||||
# ========== Health Check ==========
|
||||
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from backend.util.exceptions import NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from . import db as chat_db
|
||||
from . import stream_registry
|
||||
from .config import ChatConfig
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
@@ -1610,8 +1611,9 @@ async def _yield_tool_call(
|
||||
)
|
||||
return
|
||||
|
||||
# Generate operation ID
|
||||
# Generate operation ID and task ID
|
||||
operation_id = str(uuid_module.uuid4())
|
||||
task_id = str(uuid_module.uuid4())
|
||||
|
||||
# Build a user-friendly message based on tool and arguments
|
||||
if tool_name == "create_agent":
|
||||
@@ -1654,6 +1656,16 @@ async def _yield_tool_call(
|
||||
|
||||
# Wrap session save and task creation in try-except to release lock on failure
|
||||
try:
|
||||
# Create task in stream registry for SSE reconnection support
|
||||
await stream_registry.create_task(
|
||||
task_id=task_id,
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
operation_id=operation_id,
|
||||
)
|
||||
|
||||
# Save assistant message with tool_call FIRST (required by LLM)
|
||||
assistant_message = ChatMessage(
|
||||
role="assistant",
|
||||
@@ -1675,23 +1687,27 @@ async def _yield_tool_call(
|
||||
session.messages.append(pending_message)
|
||||
await upsert_chat_session(session)
|
||||
logger.info(
|
||||
f"Saved pending operation {operation_id} for tool {tool_name} "
|
||||
f"in session {session.session_id}"
|
||||
f"Saved pending operation {operation_id} (task_id={task_id}) "
|
||||
f"for tool {tool_name} in session {session.session_id}"
|
||||
)
|
||||
|
||||
# Store task reference in module-level set to prevent GC before completion
|
||||
task = asyncio.create_task(
|
||||
_execute_long_running_tool(
|
||||
bg_task = asyncio.create_task(
|
||||
_execute_long_running_tool_with_streaming(
|
||||
tool_name=tool_name,
|
||||
parameters=arguments,
|
||||
tool_call_id=tool_call_id,
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
)
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
_background_tasks.add(bg_task)
|
||||
bg_task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
# Associate the asyncio task with the stream registry task
|
||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||
except Exception as e:
|
||||
# Roll back appended messages to prevent data corruption on subsequent saves
|
||||
if (
|
||||
@@ -1709,6 +1725,11 @@ async def _yield_tool_call(
|
||||
|
||||
# Release the Redis lock since the background task won't be spawned
|
||||
await _mark_operation_completed(tool_call_id)
|
||||
# Mark stream registry task as failed if it was created
|
||||
try:
|
||||
await stream_registry.mark_task_completed(task_id, status="failed")
|
||||
except Exception:
|
||||
pass
|
||||
logger.error(
|
||||
f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
|
||||
)
|
||||
@@ -1722,6 +1743,7 @@ async def _yield_tool_call(
|
||||
message=started_msg,
|
||||
operation_id=operation_id,
|
||||
tool_name=tool_name,
|
||||
task_id=task_id, # Include task_id for SSE reconnection
|
||||
).model_dump_json(),
|
||||
success=True,
|
||||
)
|
||||
@@ -1791,6 +1813,9 @@ async def _execute_long_running_tool(
|
||||
|
||||
This function runs independently of the SSE connection, so the operation
|
||||
survives if the user closes their browser tab.
|
||||
|
||||
NOTE: This is the legacy function without stream registry support.
|
||||
Use _execute_long_running_tool_with_streaming for new implementations.
|
||||
"""
|
||||
try:
|
||||
# Load fresh session (not stale reference)
|
||||
@@ -1843,6 +1868,128 @@ async def _execute_long_running_tool(
|
||||
await _mark_operation_completed(tool_call_id)
|
||||
|
||||
|
||||
async def _execute_long_running_tool_with_streaming(
    tool_name: str,
    parameters: dict[str, Any],
    tool_call_id: str,
    operation_id: str,
    task_id: str,
    session_id: str,
    user_id: str | None,
) -> None:
    """Execute a long-running tool with stream registry support for SSE reconnection.

    This function runs independently of the SSE connection, publishes progress
    to the stream registry, and survives if the user closes their browser tab.
    Clients can reconnect via GET /chat/tasks/{task_id}/stream to resume streaming.

    If the external service returns a 202 Accepted (async), this function exits
    early and lets the RabbitMQ completion consumer handle the rest.
    """
    # Track whether we delegated to async processing - if so, the RabbitMQ
    # completion consumer will handle cleanup, not us
    delegated_to_async = False

    try:
        # Load fresh session (not stale reference)
        session = await get_chat_session(session_id, user_id)
        if not session:
            logger.error(f"Session {session_id} not found for background tool")
            await stream_registry.mark_task_completed(task_id, status="failed")
            return

        # Pass operation_id and task_id to the tool for async processing.
        # Underscore-prefixed keys keep them from colliding with real tool params.
        enriched_parameters = {
            **parameters,
            "_operation_id": operation_id,
            "_task_id": task_id,
        }

        # Execute the actual tool
        result = await execute_tool(
            tool_name=tool_name,
            parameters=enriched_parameters,
            tool_call_id=tool_call_id,
            user_id=user_id,
            session=session,
        )

        # Check if the tool result indicates async processing
        # (e.g., Agent Generator returned 202 Accepted)
        try:
            result_data = orjson.loads(result.output) if result.output else {}
            if result_data.get("status") == "accepted":
                logger.info(
                    f"Tool {tool_name} delegated to async processing "
                    f"(operation_id={operation_id}, task_id={task_id}). "
                    f"RabbitMQ completion consumer will handle the rest."
                )
                # Don't publish result, don't continue with LLM, and don't cleanup
                # The RabbitMQ consumer will handle everything when the external
                # service completes and publishes to the queue
                delegated_to_async = True
                return
        except (orjson.JSONDecodeError, TypeError):
            pass  # Not JSON or not async - continue normally

        # Publish tool result to stream registry
        await stream_registry.publish_chunk(task_id, result)

        # Update the pending message with result
        result_str = (
            result.output
            if isinstance(result.output, str)
            else orjson.dumps(result.output).decode("utf-8")
        )
        await _update_pending_operation(
            session_id=session_id,
            tool_call_id=tool_call_id,
            result=result_str,
        )

        logger.info(
            f"Background tool {tool_name} completed for session {session_id} "
            f"(task_id={task_id})"
        )

        # Generate LLM continuation and stream chunks to registry
        await _generate_llm_continuation_with_streaming(
            session_id=session_id,
            user_id=user_id,
            task_id=task_id,
        )

        # Mark task as completed in stream registry
        await stream_registry.mark_task_completed(task_id, status="completed")

    except Exception as e:
        logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True)
        error_response = ErrorResponse(
            message=f"Tool {tool_name} failed: {str(e)}",
        )

        # Publish error to stream registry followed by finish event
        await stream_registry.publish_chunk(
            task_id,
            StreamError(errorText=str(e)),
        )
        await stream_registry.publish_chunk(task_id, StreamFinish())

        # Persist the error as the tool result so the conversation record
        # reflects the failure.
        await _update_pending_operation(
            session_id=session_id,
            tool_call_id=tool_call_id,
            result=error_response.model_dump_json(),
        )

        # Mark task as failed in stream registry
        await stream_registry.mark_task_completed(task_id, status="failed")
    finally:
        # Only cleanup if we didn't delegate to async processing
        # For async path, the RabbitMQ completion consumer handles cleanup
        if not delegated_to_async:
            await _mark_operation_completed(tool_call_id)
|
||||
|
||||
|
||||
async def _update_pending_operation(
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
@@ -1969,3 +2116,128 @@ async def _generate_llm_continuation(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
|
||||
|
||||
|
||||
async def _generate_llm_continuation_with_streaming(
    session_id: str,
    user_id: str | None,
    task_id: str,
) -> None:
    """Generate an LLM response with streaming to the stream registry.

    This is called by background tasks to continue the conversation
    after a tool result is saved. Chunks are published to the stream registry
    so reconnecting clients can receive them.
    """
    import uuid as uuid_module

    try:
        # Load fresh session from DB (bypass cache to get the updated tool result)
        await invalidate_session_cache(session_id)
        session = await get_chat_session(session_id, user_id)
        if not session:
            logger.error(f"Session {session_id} not found for LLM continuation")
            return

        # Build system prompt
        system_prompt, _ = await _build_system_prompt(user_id)

        # Build messages in OpenAI format; the system prompt (when present)
        # must be the first message.
        messages = session.to_openai_messages()
        if system_prompt:
            from openai.types.chat import ChatCompletionSystemMessageParam

            system_message = ChatCompletionSystemMessageParam(
                role="system",
                content=system_prompt,
            )
            messages = [system_message] + messages

        # Build extra_body for tracing
        extra_body: dict[str, Any] = {
            "posthogProperties": {
                "environment": settings.config.app_env.value,
            },
        }
        if user_id:
            # Truncated to 128 chars, presumably an upstream field-length limit
            # — TODO confirm.
            extra_body["user"] = user_id[:128]
            extra_body["posthogDistinctId"] = user_id
        if session_id:
            extra_body["session_id"] = session_id[:128]

        # Make streaming LLM call (no tools - just text response)
        from typing import cast

        from openai.types.chat import ChatCompletionMessageParam

        # Generate unique IDs for AI SDK protocol
        message_id = str(uuid_module.uuid4())
        text_block_id = str(uuid_module.uuid4())

        # Publish start event
        await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
        await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))

        # Stream the response
        stream = await client.chat.completions.create(
            model=config.model,
            messages=cast(list[ChatCompletionMessageParam], messages),
            extra_body=extra_body,
            stream=True,
        )

        assistant_content = ""
        async for chunk in stream:
            if chunk.choices and chunk.choices[0].delta.content:
                delta = chunk.choices[0].delta.content
                assistant_content += delta
                # Publish delta to stream registry
                await stream_registry.publish_chunk(
                    task_id,
                    StreamTextDelta(id=text_block_id, delta=delta),
                )

        # Publish end events
        await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))

        if assistant_content:
            # Reload session from DB to avoid race condition with user messages
            fresh_session = await get_chat_session(session_id, user_id)
            if not fresh_session:
                logger.error(
                    f"Session {session_id} disappeared during LLM continuation"
                )
                return

            # Save assistant message to database
            assistant_message = ChatMessage(
                role="assistant",
                content=assistant_content,
            )
            fresh_session.messages.append(assistant_message)

            # Save to database (not cache) to persist the response
            await upsert_chat_session(fresh_session)

            # Invalidate cache so next poll/refresh gets fresh data
            await invalidate_session_cache(session_id)

            logger.info(
                f"Generated streaming LLM continuation for session {session_id} "
                f"(task_id={task_id}), response length: {len(assistant_content)}"
            )
        else:
            logger.warning(
                f"Streaming LLM continuation returned empty response for {session_id}"
            )

    except Exception as e:
        logger.error(
            f"Failed to generate streaming LLM continuation: {e}", exc_info=True
        )
        # Publish error to stream registry followed by finish event
        await stream_registry.publish_chunk(
            task_id,
            StreamError(errorText=f"Failed to generate response: {e}"),
        )
        await stream_registry.publish_chunk(task_id, StreamFinish())
|
||||
|
||||
@@ -0,0 +1,614 @@
|
||||
"""Stream registry for managing reconnectable SSE streams.
|
||||
|
||||
This module provides a registry for tracking active streaming tasks and their
|
||||
messages. It uses Redis for all state management (no in-memory state), making
|
||||
pods stateless and horizontally scalable.
|
||||
|
||||
Architecture:
|
||||
- Redis Stream: Persists all messages for replay
|
||||
- Redis Pub/Sub: Real-time delivery to subscribers
|
||||
- Redis Hash: Task metadata (status, session_id, etc.)
|
||||
|
||||
Subscribers:
|
||||
1. Replay missed messages from Redis Stream
|
||||
2. Subscribe to pub/sub channel for live updates
|
||||
3. No in-memory state required on the subscribing pod
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
|
||||
import orjson
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
from .config import ChatConfig
|
||||
from .response_model import StreamBaseResponse, StreamFinish
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
|
||||
_local_tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
|
||||
@dataclass
class ActiveTask:
    """Represents an active streaming task (metadata only, no in-memory queues)."""

    # Unique identifier of the streaming task.
    task_id: str
    # Chat session the task belongs to.
    session_id: str
    # Owning user; None for anonymous sessions.
    user_id: str | None
    # Tool call ID issued by the LLM for this operation.
    tool_call_id: str
    # Name of the tool being executed.
    tool_name: str
    # Operation ID used by external webhook callbacks.
    operation_id: str
    # Lifecycle state of the task.
    status: Literal["running", "completed", "failed"] = "running"
    # Creation timestamp (UTC, timezone-aware).
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    # Local asyncio.Task handle on the pod that spawned the work, if any.
    asyncio_task: asyncio.Task | None = None
|
||||
|
||||
|
||||
def _get_task_meta_key(task_id: str) -> str:
    """Build the Redis key under which a task's metadata hash lives."""
    return config.task_meta_prefix + task_id
|
||||
|
||||
|
||||
def _get_task_stream_key(task_id: str) -> str:
    """Build the Redis key of the task's message stream."""
    return config.task_stream_prefix + task_id
|
||||
|
||||
|
||||
def _get_operation_mapping_key(operation_id: str) -> str:
    """Build the Redis key mapping an operation_id to its task_id."""
    return config.task_op_prefix + operation_id
|
||||
|
||||
|
||||
def _get_task_pubsub_channel(task_id: str) -> str:
    """Build the Redis pub/sub channel name for real-time task delivery."""
    return config.task_pubsub_prefix + task_id
|
||||
|
||||
|
||||
async def create_task(
    task_id: str,
    session_id: str,
    user_id: str | None,
    tool_call_id: str,
    tool_name: str,
    operation_id: str,
) -> ActiveTask:
    """Register a new streaming task in Redis and return its metadata record.

    Args:
        task_id: Unique identifier for the task
        session_id: Chat session ID
        user_id: User ID (may be None for anonymous)
        tool_call_id: Tool call ID from the LLM
        tool_name: Name of the tool being executed
        operation_id: Operation ID for webhook callbacks

    Returns:
        The created ActiveTask instance (metadata only)
    """
    task = ActiveTask(
        task_id=task_id,
        session_id=session_id,
        user_id=user_id,
        tool_call_id=tool_call_id,
        tool_name=tool_name,
        operation_id=operation_id,
    )

    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)

    # Persist the task metadata as a TTL-bounded Redis hash.
    metadata = {
        "task_id": task_id,
        "session_id": session_id,
        "user_id": user_id or "",
        "tool_call_id": tool_call_id,
        "tool_name": tool_name,
        "operation_id": operation_id,
        "status": task.status,
        "created_at": task.created_at.isoformat(),
    }
    await redis.hset(meta_key, mapping=metadata)  # type: ignore[misc]
    await redis.expire(meta_key, config.stream_ttl)

    # Reverse index so webhook callbacks can resolve operation_id -> task_id.
    await redis.set(
        _get_operation_mapping_key(operation_id), task_id, ex=config.stream_ttl
    )

    logger.info(
        f"[SSE-RECONNECT] Created task {task_id} for session {session_id} in Redis"
    )

    return task
|
||||
|
||||
|
||||
async def publish_chunk(
    task_id: str,
    chunk: StreamBaseResponse,
) -> str:
    """Write a chunk to the task's Redis Stream and fan it out over pub/sub.

    The stream write gives persistence/replay; the pub/sub publish gives
    real-time delivery. No in-memory state is involved.

    Args:
        task_id: Task ID to publish to
        chunk: The stream response chunk to publish

    Returns:
        The Redis Stream message ID ("0-0" when publishing failed)
    """
    payload = chunk.model_dump_json()
    message_id = "0-0"

    try:
        redis = await get_redis_async()
        stream_key = _get_task_stream_key(task_id)
        pubsub_channel = _get_task_pubsub_channel(task_id)

        # Persist to the capped stream first so replay always sees this chunk.
        raw_id = await redis.xadd(
            stream_key,
            {"data": payload},
            maxlen=config.stream_max_length,
        )
        message_id = raw_id.decode() if not isinstance(raw_id, str) else raw_id

        # Then notify live subscribers.
        await redis.publish(pubsub_channel, payload)

        logger.debug(f"Published chunk to task {task_id}, message_id={message_id}")
    except Exception as e:
        # Best-effort: a failed publish is logged, not raised, so producers
        # are never killed by a transient Redis error.
        logger.error(
            f"Failed to publish chunk for task {task_id}: {e}",
            exc_info=True,
        )

    return message_id
|
||||
|
||||
|
||||
async def subscribe_to_task(
    task_id: str,
    user_id: str | None,
    last_message_id: str = "0-0",
) -> asyncio.Queue[StreamBaseResponse] | None:
    """Subscribe to a task's stream with replay of missed messages.

    This is fully stateless - uses Redis Stream for replay and a background
    XREAD listener for live updates.

    Args:
        task_id: Task ID to subscribe to
        user_id: User ID for ownership validation
        last_message_id: Last Redis Stream message ID received ("0-0" for full replay)

    Returns:
        An asyncio Queue that will receive stream chunks, or None if task not found
        or user doesn't have access
    """
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)
    meta: dict[Any, Any] = await redis.hgetall(meta_key)  # type: ignore[misc]

    if not meta:
        logger.warning(f"[SSE-RECONNECT] Task {task_id} not found in Redis")
        return None

    # Note: Redis client uses decode_responses=True, so keys are strings
    task_status = meta.get("status", "")
    task_user_id = meta.get("user_id", "") or None

    logger.info(f"[SSE-RECONNECT] Subscribing to task {task_id}: status={task_status}")

    # Validate ownership
    if user_id and task_user_id and task_user_id != user_id:
        logger.warning(
            f"User {user_id} attempted to subscribe to task {task_id} "
            f"owned by {task_user_id}"
        )
        return None

    subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
    stream_key = _get_task_stream_key(task_id)

    # Step 1: Replay messages from Redis Stream.
    # BUGFIX: this read must be NON-blocking. The previous code passed
    # block=0, which in Redis XREAD semantics means "block forever" — so a
    # subscriber attaching before the first chunk was written would hang
    # indefinitely instead of falling through to the live listener.
    messages = await redis.xread({stream_key: last_message_id}, count=1000)

    replayed_count = 0
    replay_last_id = last_message_id
    if messages:
        for _stream_name, stream_messages in messages:
            for msg_id, msg_data in stream_messages:
                replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
                # Note: Redis client uses decode_responses=True, so keys are strings
                if "data" in msg_data:
                    try:
                        chunk_data = orjson.loads(msg_data["data"])
                        chunk = _reconstruct_chunk(chunk_data)
                        if chunk:
                            await subscriber_queue.put(chunk)
                            replayed_count += 1
                    except Exception as e:
                        logger.warning(f"Failed to replay message: {e}")

    logger.info(
        f"[SSE-RECONNECT] Task {task_id}: replayed {replayed_count} messages "
        f"(last_id={replay_last_id})"
    )

    # Step 2: If task is still running, start stream listener for live updates
    if task_status == "running":
        logger.info(
            f"[SSE-RECONNECT] Task {task_id} is running, starting stream listener"
        )
        asyncio.create_task(_stream_listener(task_id, subscriber_queue, replay_last_id))
    else:
        # Task is completed/failed - add finish marker
        logger.info(
            f"[SSE-RECONNECT] Task {task_id} is {task_status}, adding finish marker"
        )
        await subscriber_queue.put(StreamFinish())

    return subscriber_queue
|
||||
|
||||
|
||||
async def _stream_listener(
    task_id: str,
    subscriber_queue: asyncio.Queue[StreamBaseResponse],
    last_replayed_id: str,
) -> None:
    """Listen to Redis Stream for new messages using blocking XREAD.

    This approach avoids the duplicate message issue that can occur with pub/sub
    when messages are published during the gap between replay and subscription.

    Args:
        task_id: Task ID to listen for
        subscriber_queue: Queue to deliver messages to
        last_replayed_id: Last message ID from replay (continue from here)
    """
    try:
        redis = await get_redis_async()
        stream_key = _get_task_stream_key(task_id)
        current_id = last_replayed_id

        logger.debug(
            f"[SSE-RECONNECT] Stream listener started for task {task_id}, "
            f"from ID {current_id}"
        )

        while True:
            # Block for up to 30 seconds waiting for new messages
            # This allows periodic checking if task is still running
            messages = await redis.xread(
                {stream_key: current_id}, block=30000, count=100
            )

            if not messages:
                # Timeout - check if task is still running
                meta_key = _get_task_meta_key(task_id)
                status = await redis.hget(meta_key, "status")  # type: ignore[misc]
                if status and status != "running":
                    logger.info(
                        f"[SSE-RECONNECT] Task {task_id} no longer running "
                        f"(status={status}), stopping listener"
                    )
                    subscriber_queue.put_nowait(StreamFinish())
                    break
                continue

            for _stream_name, stream_messages in messages:
                for msg_id, msg_data in stream_messages:
                    # Advance the cursor even for malformed entries so they are
                    # never re-read on the next XREAD.
                    current_id = msg_id if isinstance(msg_id, str) else msg_id.decode()

                    if "data" not in msg_data:
                        continue

                    try:
                        chunk_data = orjson.loads(msg_data["data"])
                        chunk = _reconstruct_chunk(chunk_data)
                        if chunk:
                            try:
                                subscriber_queue.put_nowait(chunk)
                            except asyncio.QueueFull:
                                # NOTE(review): only reachable if the subscriber
                                # queue was created with a maxsize — confirm;
                                # the chunk is dropped for this subscriber.
                                logger.warning(
                                    f"Subscriber queue full for task {task_id}"
                                )

                            # Stop listening on finish
                            if isinstance(chunk, StreamFinish):
                                logger.info(
                                    f"[SSE-RECONNECT] Task {task_id} finished "
                                    "via stream"
                                )
                                return
                    except Exception as e:
                        logger.warning(f"Error processing stream message: {e}")

    except asyncio.CancelledError:
        logger.debug(f"[SSE-RECONNECT] Stream listener cancelled for task {task_id}")
    except Exception as e:
        logger.error(f"Stream listener error for task {task_id}: {e}")
        # On error, send finish to unblock subscriber
        try:
            subscriber_queue.put_nowait(StreamFinish())
        except asyncio.QueueFull:
            pass
|
||||
|
||||
|
||||
async def mark_task_completed(
    task_id: str,
    status: Literal["completed", "failed"] = "completed",
) -> None:
    """Finalize a task: emit a finish event and persist its terminal status.

    Args:
        task_id: Task ID to mark as completed
        status: Final status ("completed" or "failed")
    """
    # Emit the finish chunk first (goes to Redis Stream + pub/sub) so any
    # subscribers are unblocked promptly.
    await publish_chunk(task_id, StreamFinish())

    # Persist the terminal status in the task's Redis metadata hash.
    redis_conn = await get_redis_async()
    await redis_conn.hset(_get_task_meta_key(task_id), "status", status)  # type: ignore[misc]

    # Drop the local asyncio.Task reference, if this pod was tracking one.
    _local_tasks.pop(task_id, None)

    logger.info(f"[SSE-RECONNECT] Marked task {task_id} as {status}")
|
||||
|
||||
|
||||
async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
    """Look up the task associated with an operation ID.

    Used by webhook callbacks to locate the task to update.

    Args:
        operation_id: Operation ID to search for

    Returns:
        ActiveTask if found, None otherwise
    """
    redis_conn = await get_redis_async()
    mapping_key = _get_operation_mapping_key(operation_id)
    raw_task_id = await redis_conn.get(mapping_key)

    logger.info(
        f"[SSE-RECONNECT] find_task_by_operation_id: "
        f"op_key={mapping_key}, task_id_from_redis={raw_task_id!r}"
    )

    if not raw_task_id:
        logger.info(f"[SSE-RECONNECT] No task_id found for operation {operation_id}")
        return None

    # The client may hand back bytes depending on configuration; normalize.
    resolved_id = (
        raw_task_id.decode() if isinstance(raw_task_id, bytes) else raw_task_id
    )
    logger.info(f"[SSE-RECONNECT] Looking up task by task_id={resolved_id}")
    return await get_task(resolved_id)
|
||||
|
||||
|
||||
async def get_task(task_id: str) -> ActiveTask | None:
    """Load a task's metadata hash from Redis and build an ActiveTask.

    Args:
        task_id: Task ID to look up

    Returns:
        ActiveTask if found, None otherwise
    """
    redis_conn = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)
    meta: dict[Any, Any] = await redis_conn.hgetall(meta_key)  # type: ignore[misc]

    logger.info(
        f"[SSE-RECONNECT] get_task: meta_key={meta_key}, "
        f"meta_keys={list(meta.keys()) if meta else 'empty'}, "
        f"meta={meta}"
    )

    if not meta:
        logger.info(f"[SSE-RECONNECT] No metadata found for task {task_id}")
        return None

    # The Redis client uses decode_responses=True, so keys/values are strings.
    def _field(name: str, default: str = "") -> str:
        return meta.get(name, default)

    task = ActiveTask(
        task_id=_field("task_id"),
        session_id=_field("session_id"),
        user_id=_field("user_id") or None,
        tool_call_id=_field("tool_call_id"),
        tool_name=_field("tool_name"),
        operation_id=_field("operation_id"),
        status=_field("status", "running"),  # type: ignore[arg-type]
    )
    logger.info(
        f"[SSE-RECONNECT] get_task returning: task_id={task.task_id}, "
        f"session_id={task.session_id}, operation_id={task.operation_id}"
    )
    return task
|
||||
|
||||
|
||||
async def get_active_task_for_session(
    session_id: str,
    user_id: str | None = None,
) -> tuple[ActiveTask | None, str]:
    """Get the active (running) task for a session, if any.

    Scans Redis for tasks matching the session_id with status="running".

    NOTE(review): this is a full SCAN over every task metadata key, so cost
    grows with the platform-wide task count — presumably acceptable while
    counts are small; consider a session_id -> task_id index if this gets hot.

    Args:
        session_id: Session ID to look up
        user_id: User ID for ownership validation (optional)

    Returns:
        Tuple of (ActiveTask if found and running, last_message_id from Redis Stream)
    """
    logger.info(f"[SSE-RECONNECT] Looking for active task for session {session_id}")

    redis = await get_redis_async()

    # Scan Redis for task metadata keys
    cursor = 0
    tasks_checked = 0

    while True:
        cursor, keys = await redis.scan(
            cursor, match=f"{config.task_meta_prefix}*", count=100
        )

        for key in keys:
            tasks_checked += 1
            meta: dict[Any, Any] = await redis.hgetall(key)  # type: ignore[misc]
            if not meta:
                continue

            # Note: Redis client uses decode_responses=True, so keys/values are strings
            task_session_id = meta.get("session_id", "")
            task_status = meta.get("status", "")
            task_user_id = meta.get("user_id", "") or None
            task_id = meta.get("task_id", "")

            # Log tasks found for this session (any status) to aid debugging
            if task_session_id == session_id:
                logger.info(
                    f"[SSE-RECONNECT] Found task for session: "
                    f"task_id={task_id}, status={task_status}"
                )

            if task_session_id == session_id and task_status == "running":
                # Validate ownership: skip tasks owned by a different user
                if user_id and task_user_id and task_user_id != user_id:
                    logger.info(f"[SSE-RECONNECT] Task {task_id} ownership mismatch")
                    continue

                # Get the last message ID from the task's Redis Stream so the
                # caller can resume reading exactly where the stream ends now.
                stream_key = _get_task_stream_key(task_id)
                last_id = "0-0"  # stream origin: replay from the beginning
                try:
                    messages = await redis.xrevrange(stream_key, count=1)
                    if messages:
                        msg_id = messages[0][0]
                        last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
                except Exception as e:
                    # Best-effort: fall back to replaying from "0-0"
                    logger.warning(f"Failed to get last message ID: {e}")

                logger.info(
                    f"[SSE-RECONNECT] Found active task: task_id={task_id}, "
                    f"last_message_id={last_id}"
                )

                return (
                    ActiveTask(
                        task_id=task_id,
                        session_id=task_session_id,
                        user_id=task_user_id,
                        tool_call_id=meta.get("tool_call_id", ""),
                        tool_name=meta.get("tool_name", ""),
                        operation_id=meta.get("operation_id", ""),
                        status="running",
                    ),
                    last_id,
                )

        # SCAN returns cursor 0 when the full keyspace iteration is complete
        if cursor == 0:
            break

    logger.info(
        f"[SSE-RECONNECT] No active task found for session {session_id} "
        f"(checked {tasks_checked} tasks)"
    )
    return None, "0-0"
|
||||
|
||||
|
||||
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
    """Reconstruct a StreamBaseResponse from JSON data.

    Args:
        chunk_data: Parsed JSON data from Redis

    Returns:
        Reconstructed response object, or None if unknown type
    """
    from .response_model import (
        ResponseType,
        StreamError,
        StreamFinish,
        StreamHeartbeat,
        StreamStart,
        StreamTextDelta,
        StreamTextEnd,
        StreamTextStart,
        StreamToolInputAvailable,
        StreamToolInputStart,
        StreamToolOutputAvailable,
        StreamUsage,
    )

    chunk_type = chunk_data.get("type")

    try:
        # Dispatch table: wire-format type tag -> concrete response model.
        model_by_type = {
            ResponseType.START.value: StreamStart,
            ResponseType.FINISH.value: StreamFinish,
            ResponseType.TEXT_START.value: StreamTextStart,
            ResponseType.TEXT_DELTA.value: StreamTextDelta,
            ResponseType.TEXT_END.value: StreamTextEnd,
            ResponseType.TOOL_INPUT_START.value: StreamToolInputStart,
            ResponseType.TOOL_INPUT_AVAILABLE.value: StreamToolInputAvailable,
            ResponseType.TOOL_OUTPUT_AVAILABLE.value: StreamToolOutputAvailable,
            ResponseType.ERROR.value: StreamError,
            ResponseType.USAGE.value: StreamUsage,
            ResponseType.HEARTBEAT.value: StreamHeartbeat,
        }

        model_cls = model_by_type.get(chunk_type)
        if model_cls is None:
            logger.warning(f"Unknown chunk type: {chunk_type}")
            return None
        return model_cls(**chunk_data)
    except Exception as e:
        logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}")
        return None
|
||||
|
||||
|
||||
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
    """Track the asyncio.Task for a task (local reference only).

    This is just for cleanup purposes - the task state is in Redis.
    (The coroutine performs no awaiting; async signature kept for API symmetry.)

    Args:
        task_id: Task ID
        asyncio_task: The asyncio Task to track
    """
    _local_tasks[task_id] = asyncio_task
|
||||
|
||||
|
||||
async def unsubscribe_from_task(
    task_id: str,
    subscriber_queue: asyncio.Queue[StreamBaseResponse],
) -> None:
    """Clean up when a subscriber disconnects.

    With Redis-based pub/sub, there's no explicit unsubscription needed.
    The pub/sub listener task will be garbage collected when the subscriber
    stops reading from the queue.

    Args:
        task_id: Task ID
        subscriber_queue: The subscriber's queue (unused, kept for API compat)
    """
    # No-op - pub/sub listener cleans up automatically; we only log for tracing
    logger.debug(f"[SSE-RECONNECT] Subscriber disconnected from task {task_id}")
|
||||
@@ -549,15 +549,19 @@ async def decompose_goal(
|
||||
async def generate_agent(
|
||||
instructions: DecompositionResult | dict[str, Any],
|
||||
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Generate agent JSON from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables RabbitMQ callback)
|
||||
task_id: Task ID for async processing (enables RabbitMQ callback)
|
||||
|
||||
Returns:
|
||||
Agent JSON dict, error dict {"type": "error", ...}, or None on error
|
||||
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
@@ -565,8 +569,13 @@ async def generate_agent(
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent")
|
||||
result = await generate_agent_external(
|
||||
dict(instructions), _to_dict_list(library_agents)
|
||||
dict(instructions), _to_dict_list(library_agents), operation_id, task_id
|
||||
)
|
||||
|
||||
# Don't modify async response
|
||||
if result and result.get("status") == "accepted":
|
||||
return result
|
||||
|
||||
if result:
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
return result
|
||||
@@ -806,6 +815,8 @@ async def generate_agent_patch(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: list[AgentSummary] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Update an existing agent using natural language.
|
||||
|
||||
@@ -818,10 +829,12 @@ async def generate_agent_patch(
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables RabbitMQ callback)
|
||||
task_id: Task ID for async processing (enables RabbitMQ callback)
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||
error dict {"type": "error", ...}, or None on unexpected error
|
||||
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
@@ -829,5 +842,9 @@ async def generate_agent_patch(
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
||||
return await generate_agent_patch_external(
|
||||
update_request, current_agent, _to_dict_list(library_agents)
|
||||
update_request,
|
||||
current_agent,
|
||||
_to_dict_list(library_agents),
|
||||
operation_id,
|
||||
task_id,
|
||||
)
|
||||
|
||||
@@ -213,24 +213,45 @@ async def decompose_goal_external(
|
||||
async def generate_agent_external(
|
||||
instructions: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate an agent from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables RabbitMQ callback)
|
||||
task_id: Task ID for async processing (enables RabbitMQ callback)
|
||||
|
||||
Returns:
|
||||
Agent JSON dict on success, or error dict {"type": "error", ...} on error
|
||||
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
payload: dict[str, Any] = {"instructions": instructions}
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
if operation_id and task_id:
|
||||
payload["operation_id"] = operation_id
|
||||
payload["task_id"] = task_id
|
||||
|
||||
try:
|
||||
response = await client.post("/api/generate-agent", json=payload)
|
||||
|
||||
# Handle 202 Accepted for async processing
|
||||
if response.status_code == 202:
|
||||
logger.info(
|
||||
f"Agent Generator accepted async request "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return {
|
||||
"status": "accepted",
|
||||
"operation_id": operation_id,
|
||||
"task_id": task_id,
|
||||
}
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
@@ -262,6 +283,8 @@ async def generate_agent_patch_external(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate a patch for an existing agent.
|
||||
|
||||
@@ -269,21 +292,40 @@ async def generate_agent_patch_external(
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables RabbitMQ callback)
|
||||
task_id: Task ID for async processing (enables RabbitMQ callback)
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict, or error dict on error
|
||||
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
payload: dict[str, Any] = {
|
||||
"update_request": update_request,
|
||||
"current_agent_json": current_agent,
|
||||
}
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
if operation_id and task_id:
|
||||
payload["operation_id"] = operation_id
|
||||
payload["task_id"] = task_id
|
||||
|
||||
try:
|
||||
response = await client.post("/api/update-agent", json=payload)
|
||||
|
||||
# Handle 202 Accepted for async processing
|
||||
if response.status_code == 202:
|
||||
logger.info(
|
||||
f"Agent Generator accepted async update request "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return {
|
||||
"status": "accepted",
|
||||
"operation_id": operation_id,
|
||||
"task_id": task_id,
|
||||
}
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
AsyncProcessingResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
@@ -98,6 +99,10 @@ class CreateAgentTool(BaseTool):
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
# Extract async processing params (passed by long-running tool handler)
|
||||
operation_id = kwargs.get("_operation_id")
|
||||
task_id = kwargs.get("_task_id")
|
||||
|
||||
if not description:
|
||||
return ErrorResponse(
|
||||
message="Please provide a description of what the agent should do.",
|
||||
@@ -219,7 +224,12 @@ class CreateAgentTool(BaseTool):
|
||||
logger.warning(f"Failed to enrich library agents from steps: {e}")
|
||||
|
||||
try:
|
||||
agent_json = await generate_agent(decomposition_result, library_agents)
|
||||
agent_json = await generate_agent(
|
||||
decomposition_result,
|
||||
library_agents,
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
@@ -263,6 +273,19 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if Agent Generator accepted for async processing
|
||||
if agent_json.get("status") == "accepted":
|
||||
logger.info(
|
||||
f"Agent generation delegated to async processing "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return AsyncProcessingResponse(
|
||||
message="Agent generation started. You'll be notified when it's complete.",
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agent_name = agent_json.get("name", "Generated Agent")
|
||||
agent_description = agent_json.get("description", "")
|
||||
node_count = len(agent_json.get("nodes", []))
|
||||
|
||||
@@ -17,6 +17,7 @@ from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
AsyncProcessingResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
@@ -104,6 +105,10 @@ class EditAgentTool(BaseTool):
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
# Extract async processing params (passed by long-running tool handler)
|
||||
operation_id = kwargs.get("_operation_id")
|
||||
task_id = kwargs.get("_task_id")
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the agent ID to edit.",
|
||||
@@ -149,7 +154,11 @@ class EditAgentTool(BaseTool):
|
||||
|
||||
try:
|
||||
result = await generate_agent_patch(
|
||||
update_request, current_agent, library_agents
|
||||
update_request,
|
||||
current_agent,
|
||||
library_agents,
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
@@ -169,6 +178,20 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if Agent Generator accepted for async processing
|
||||
if result.get("status") == "accepted":
|
||||
logger.info(
|
||||
f"Agent edit delegated to async processing "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return AsyncProcessingResponse(
|
||||
message="Agent edit started. You'll be notified when it's complete.",
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if the result is an error from the external service
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
error_type = result.get("error_type", "unknown")
|
||||
|
||||
@@ -352,11 +352,15 @@ class OperationStartedResponse(ToolResponseBase):
|
||||
|
||||
This is returned immediately to the client while the operation continues
|
||||
to execute. The user can close the tab and check back later.
|
||||
|
||||
The task_id can be used to reconnect to the SSE stream via
|
||||
GET /chat/tasks/{task_id}/stream?last_idx=0
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||
operation_id: str
|
||||
tool_name: str
|
||||
task_id: str | None = None # For SSE reconnection
|
||||
|
||||
|
||||
class OperationPendingResponse(ToolResponseBase):
|
||||
@@ -380,3 +384,20 @@ class OperationInProgressResponse(ToolResponseBase):
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class AsyncProcessingResponse(ToolResponseBase):
|
||||
"""Response when an operation has been delegated to async processing.
|
||||
|
||||
This is returned by tools when the external service accepts the request
|
||||
for async processing (HTTP 202 Accepted). The RabbitMQ completion consumer
|
||||
will handle the result when the external service completes.
|
||||
|
||||
The status field is specifically "accepted" to allow the long-running tool
|
||||
handler to detect this response and skip LLM continuation.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||
status: str = "accepted" # Must be "accepted" for detection
|
||||
operation_id: str | None = None
|
||||
task_id: str | None = None
|
||||
|
||||
@@ -40,6 +40,10 @@ import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.api.features.chat.completion_consumer import (
|
||||
start_completion_consumer,
|
||||
stop_completion_consumer,
|
||||
)
|
||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
@@ -118,9 +122,21 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
# Start chat completion consumer for RabbitMQ notifications
|
||||
try:
|
||||
await start_completion_consumer()
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not start chat completion consumer: {e}")
|
||||
|
||||
with launch_darkly_context():
|
||||
yield
|
||||
|
||||
# Stop chat completion consumer
|
||||
try:
|
||||
await stop_completion_consumer()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error stopping chat completion consumer: {e}")
|
||||
|
||||
try:
|
||||
await shutdown_cloud_storage_handler()
|
||||
except Exception as e:
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import fastapi
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
from backend.api.features.integrations.router import router as integrations_router
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import utils as webhooks_utils
|
||||
|
||||
|
||||
def test_webhook_ingress_url_matches_route(monkeypatch) -> None:
    """Verify webhook_ingress_url() builds a URL whose base matches the
    configured platform base URL and whose path matches the actual POST
    ingress route registered on the integrations router."""
    # NOTE(review): `app` is unused beyond exercising include_router —
    # presumably a smoke check that the router mounts cleanly; confirm intent.
    app = fastapi.FastAPI()
    app.include_router(integrations_router, prefix="/api/integrations")

    provider = ProviderName.GITHUB
    webhook_id = "webhook_123"
    base_url = "https://example.com"

    # Pin the configured platform base URL so the test is deterministic.
    monkeypatch.setattr(webhooks_utils.app_config, "platform_base_url", base_url)

    # Locate the POST ingress route declared on the router itself.
    route = next(
        route
        for route in integrations_router.routes
        if isinstance(route, APIRoute)
        and route.path == "/{provider}/webhooks/{webhook_id}/ingress"
        and "POST" in route.methods
    )
    # Expand the route template with concrete values, under the same prefix
    # the app mounts the router at.
    expected_path = f"/api/integrations{route.path}".format(
        provider=provider.value,
        webhook_id=webhook_id,
    )
    actual_url = urlparse(webhooks_utils.webhook_ingress_url(provider, webhook_id))
    expected_base = urlparse(base_url)

    assert (actual_url.scheme, actual_url.netloc) == (
        expected_base.scheme,
        expected_base.netloc,
    )
    assert actual_url.path == expected_path
||||
@@ -11,7 +11,6 @@ import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { usePathname, useSearchParams } from "next/navigation";
|
||||
import { useRef } from "react";
|
||||
import { useCopilotStore } from "../../copilot-page-store";
|
||||
import { useCopilotSessionId } from "../../useCopilotSessionId";
|
||||
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
|
||||
@@ -70,41 +69,16 @@ export function useCopilotShell() {
|
||||
});
|
||||
|
||||
const stopStream = useChatStore((s) => s.stopStream);
|
||||
const onStreamComplete = useChatStore((s) => s.onStreamComplete);
|
||||
const isStreaming = useCopilotStore((s) => s.isStreaming);
|
||||
const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);
|
||||
const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession);
|
||||
const openInterruptModal = useCopilotStore((s) => s.openInterruptModal);
|
||||
|
||||
const pendingActionRef = useRef<(() => void) | null>(null);
|
||||
|
||||
async function stopCurrentStream() {
|
||||
if (!currentSessionId) return;
|
||||
|
||||
setIsSwitchingSession(true);
|
||||
await new Promise<void>((resolve) => {
|
||||
const unsubscribe = onStreamComplete((completedId) => {
|
||||
if (completedId === currentSessionId) {
|
||||
clearTimeout(timeout);
|
||||
unsubscribe();
|
||||
resolve();
|
||||
}
|
||||
});
|
||||
const timeout = setTimeout(() => {
|
||||
unsubscribe();
|
||||
resolve();
|
||||
}, 3000);
|
||||
stopStream(currentSessionId);
|
||||
});
|
||||
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(currentSessionId),
|
||||
});
|
||||
setIsSwitchingSession(false);
|
||||
}
|
||||
|
||||
function selectSession(sessionId: string) {
|
||||
function handleSessionClick(sessionId: string) {
|
||||
if (sessionId === currentSessionId) return;
|
||||
|
||||
// Stop current stream - SSE reconnection allows resuming later
|
||||
if (currentSessionId) {
|
||||
stopStream(currentSessionId);
|
||||
}
|
||||
|
||||
if (recentlyCreatedSessionsRef.current.has(sessionId)) {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(sessionId),
|
||||
@@ -114,7 +88,12 @@ export function useCopilotShell() {
|
||||
if (isMobile) handleCloseDrawer();
|
||||
}
|
||||
|
||||
function startNewChat() {
|
||||
function handleNewChatClick() {
|
||||
// Stop current stream - SSE reconnection allows resuming later
|
||||
if (currentSessionId) {
|
||||
stopStream(currentSessionId);
|
||||
}
|
||||
|
||||
resetPagination();
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
@@ -123,32 +102,6 @@ export function useCopilotShell() {
|
||||
if (isMobile) handleCloseDrawer();
|
||||
}
|
||||
|
||||
function handleSessionClick(sessionId: string) {
|
||||
if (sessionId === currentSessionId) return;
|
||||
|
||||
if (isStreaming) {
|
||||
pendingActionRef.current = async () => {
|
||||
await stopCurrentStream();
|
||||
selectSession(sessionId);
|
||||
};
|
||||
openInterruptModal(pendingActionRef.current);
|
||||
} else {
|
||||
selectSession(sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
function handleNewChatClick() {
|
||||
if (isStreaming) {
|
||||
pendingActionRef.current = async () => {
|
||||
await stopCurrentStream();
|
||||
startNewChat();
|
||||
};
|
||||
openInterruptModal(pendingActionRef.current);
|
||||
} else {
|
||||
startNewChat();
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
isMobile,
|
||||
isDrawerOpen,
|
||||
|
||||
@@ -4,53 +4,25 @@ import { create } from "zustand";
|
||||
|
||||
interface CopilotStoreState {
|
||||
isStreaming: boolean;
|
||||
isSwitchingSession: boolean;
|
||||
isCreatingSession: boolean;
|
||||
isInterruptModalOpen: boolean;
|
||||
pendingAction: (() => void) | null;
|
||||
}
|
||||
|
||||
interface CopilotStoreActions {
|
||||
setIsStreaming: (isStreaming: boolean) => void;
|
||||
setIsSwitchingSession: (isSwitchingSession: boolean) => void;
|
||||
setIsCreatingSession: (isCreating: boolean) => void;
|
||||
openInterruptModal: (onConfirm: () => void) => void;
|
||||
confirmInterrupt: () => void;
|
||||
cancelInterrupt: () => void;
|
||||
}
|
||||
|
||||
type CopilotStore = CopilotStoreState & CopilotStoreActions;
|
||||
|
||||
export const useCopilotStore = create<CopilotStore>((set, get) => ({
|
||||
export const useCopilotStore = create<CopilotStore>((set) => ({
|
||||
isStreaming: false,
|
||||
isSwitchingSession: false,
|
||||
isCreatingSession: false,
|
||||
isInterruptModalOpen: false,
|
||||
pendingAction: null,
|
||||
|
||||
setIsStreaming(isStreaming) {
|
||||
set({ isStreaming });
|
||||
},
|
||||
|
||||
setIsSwitchingSession(isSwitchingSession) {
|
||||
set({ isSwitchingSession });
|
||||
},
|
||||
|
||||
setIsCreatingSession(isCreatingSession) {
|
||||
set({ isCreatingSession });
|
||||
},
|
||||
|
||||
openInterruptModal(onConfirm) {
|
||||
set({ isInterruptModalOpen: true, pendingAction: onConfirm });
|
||||
},
|
||||
|
||||
confirmInterrupt() {
|
||||
const { pendingAction } = get();
|
||||
set({ isInterruptModalOpen: false, pendingAction: null });
|
||||
if (pendingAction) pendingAction();
|
||||
},
|
||||
|
||||
cancelInterrupt() {
|
||||
set({ isInterruptModalOpen: false, pendingAction: null });
|
||||
},
|
||||
}));
|
||||
|
||||
@@ -5,15 +5,10 @@ import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Chat } from "@/components/contextual/Chat/Chat";
|
||||
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { useCopilotStore } from "./copilot-page-store";
|
||||
import { useCopilotPage } from "./useCopilotPage";
|
||||
|
||||
export default function CopilotPage() {
|
||||
const { state, handlers } = useCopilotPage();
|
||||
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
|
||||
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
|
||||
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
|
||||
const {
|
||||
greetingName,
|
||||
quickActions,
|
||||
@@ -40,42 +35,6 @@ export default function CopilotPage() {
|
||||
onSessionNotFound={handleSessionNotFound}
|
||||
onStreamingChange={handleStreamingChange}
|
||||
/>
|
||||
<Dialog
|
||||
title="Interrupt current chat?"
|
||||
styling={{ maxWidth: 300, width: "100%" }}
|
||||
controlled={{
|
||||
isOpen: isInterruptModalOpen,
|
||||
set: (open) => {
|
||||
if (!open) cancelInterrupt();
|
||||
},
|
||||
}}
|
||||
onClose={cancelInterrupt}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<div className="flex flex-col gap-4">
|
||||
<Text variant="body">
|
||||
The current chat response will be interrupted. Are you sure you
|
||||
want to continue?
|
||||
</Text>
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
onClick={cancelInterrupt}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
type="button"
|
||||
variant="primary"
|
||||
onClick={confirmInterrupt}
|
||||
>
|
||||
Continue
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import {
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
|
||||
import {
|
||||
Flag,
|
||||
type FlagValues,
|
||||
@@ -26,20 +25,12 @@ export function useCopilotPage() {
|
||||
const queryClient = useQueryClient();
|
||||
const { user, isLoggedIn, isUserLoading } = useSupabase();
|
||||
const { toast } = useToast();
|
||||
const { completeStep } = useOnboarding();
|
||||
|
||||
const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
|
||||
const setIsStreaming = useCopilotStore((s) => s.setIsStreaming);
|
||||
const isCreating = useCopilotStore((s) => s.isCreatingSession);
|
||||
const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession);
|
||||
|
||||
// Complete VISIT_COPILOT onboarding step to grant $5 welcome bonus
|
||||
useEffect(() => {
|
||||
if (isLoggedIn) {
|
||||
completeStep("VISIT_COPILOT");
|
||||
}
|
||||
}, [completeStep, isLoggedIn]);
|
||||
|
||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||
const flags = useFlags<FlagValues>();
|
||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
import { environment } from "@/services/environment";
|
||||
import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
|
||||
import { NextRequest } from "next/server";
|
||||
|
||||
/**
|
||||
* SSE Proxy for task stream reconnection.
|
||||
*
|
||||
* This endpoint allows clients to reconnect to an ongoing or recently completed
|
||||
* background task's stream. It replays missed messages from Redis Streams and
|
||||
* subscribes to live updates if the task is still running.
|
||||
*
|
||||
* Client contract:
|
||||
* 1. When receiving an operation_started event, store the task_id
|
||||
* 2. To reconnect: GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
|
||||
* 3. Messages are replayed from the last_message_id position
|
||||
* 4. Stream ends when "finish" event is received
|
||||
*/
|
||||
export async function GET(
|
||||
request: NextRequest,
|
||||
{ params }: { params: Promise<{ taskId: string }> },
|
||||
) {
|
||||
const { taskId } = await params;
|
||||
const searchParams = request.nextUrl.searchParams;
|
||||
const lastMessageId = searchParams.get("last_message_id") || "0-0";
|
||||
|
||||
try {
|
||||
// Get auth token from server-side session
|
||||
const token = await getServerAuthToken();
|
||||
|
||||
// Build backend URL
|
||||
const backendUrl = environment.getAGPTServerBaseUrl();
|
||||
const streamUrl = new URL(`/api/chat/tasks/${taskId}/stream`, backendUrl);
|
||||
streamUrl.searchParams.set("last_message_id", lastMessageId);
|
||||
|
||||
// Forward request to backend with auth header
|
||||
const headers: Record<string, string> = {
|
||||
Accept: "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
Connection: "keep-alive",
|
||||
};
|
||||
|
||||
if (token) {
|
||||
headers["Authorization"] = `Bearer ${token}`;
|
||||
}
|
||||
|
||||
const response = await fetch(streamUrl.toString(), {
|
||||
method: "GET",
|
||||
headers,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text();
|
||||
return new Response(error, {
|
||||
status: response.status,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
});
|
||||
}
|
||||
|
||||
// Return the SSE stream directly
|
||||
return new Response(response.body, {
|
||||
headers: {
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache, no-transform",
|
||||
Connection: "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Task stream proxy error:", error);
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
error: "Failed to connect to task stream",
|
||||
detail: error instanceof Error ? error.message : String(error),
|
||||
}),
|
||||
{
|
||||
status: 500,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -939,6 +939,63 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/operations/{operation_id}/complete": {
|
||||
"post": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Complete Operation",
|
||||
"description": "External completion webhook for long-running operations.\n\nCalled by Agent Generator (or other services) when an operation completes.\nThis triggers the stream registry to publish completion and continue LLM generation.\n\nArgs:\n operation_id: The operation ID to complete.\n request: Completion payload with success status and result/error.\n x_api_key: Internal API key for authentication.\n\nReturns:\n dict: Status of the completion.\n\nRaises:\n HTTPException: If API key is invalid or operation not found.",
|
||||
"operationId": "postV2CompleteOperation",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "operation_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Operation Id" }
|
||||
},
|
||||
{
|
||||
"name": "x-api-key",
|
||||
"in": "header",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "X-Api-Key"
|
||||
}
|
||||
}
|
||||
],
|
||||
"requestBody": {
|
||||
"required": true,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/OperationCompleteRequest"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"title": "Response Postv2Completeoperation"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/sessions": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
@@ -1022,7 +1079,7 @@
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Get Session",
|
||||
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, or None if not found.",
|
||||
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\nIf there's an active stream for this session, returns the task_id for reconnection.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, including active_stream info if applicable.",
|
||||
"operationId": "getV2GetSession",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
@@ -1157,7 +1214,7 @@
|
||||
"post": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Stream Chat Post",
|
||||
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
|
||||
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to Redis for reconnection support. If the client disconnects,\nthey can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks. First chunk is a \"start\" event\n containing the task_id for reconnection.",
|
||||
"operationId": "postV2StreamChatPost",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
@@ -1195,6 +1252,94 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/tasks/{task_id}": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Get Task Status",
|
||||
"description": "Get the status of a long-running task.\n\nArgs:\n task_id: The task ID to check.\n user_id: Authenticated user ID for ownership validation.\n\nReturns:\n dict: Task status including task_id, status, tool_name, and operation_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.",
|
||||
"operationId": "getV2GetTaskStatus",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "task_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Task Id" }
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"title": "Response Getv2Gettaskstatus"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/tasks/{task_id}/stream": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Stream Task",
|
||||
"description": "Reconnect to a long-running task's SSE stream.\n\nWhen a long-running operation (like agent generation) starts, the client\nreceives a task_id. If the connection drops, the client can reconnect\nusing this endpoint to resume receiving updates.\n\nArgs:\n task_id: The task ID from the operation_started response.\n user_id: Authenticated user ID for ownership validation.\n last_message_id: Last Redis Stream message ID received (\"0-0\" for full replay).\n\nReturns:\n StreamingResponse: SSE-formatted response chunks starting after last_message_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.",
|
||||
"operationId": "getV2StreamTask",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "task_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Task Id" }
|
||||
},
|
||||
{
|
||||
"name": "last_message_id",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
|
||||
"default": "0-0",
|
||||
"title": "Last Message Id"
|
||||
},
|
||||
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay."
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": { "application/json": { "schema": {} } }
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/credits": {
|
||||
"get": {
|
||||
"tags": ["v1", "credits"],
|
||||
@@ -6168,6 +6313,16 @@
|
||||
"title": "AccuracyTrendsResponse",
|
||||
"description": "Response model for accuracy trends and alerts."
|
||||
},
|
||||
"ActiveStreamInfo": {
|
||||
"properties": {
|
||||
"task_id": { "type": "string", "title": "Task Id" },
|
||||
"last_message_id": { "type": "string", "title": "Last Message Id" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["task_id", "last_message_id"],
|
||||
"title": "ActiveStreamInfo",
|
||||
"description": "Information about an active stream for reconnection."
|
||||
},
|
||||
"AddUserCreditsResponse": {
|
||||
"properties": {
|
||||
"new_balance": { "type": "integer", "title": "New Balance" },
|
||||
@@ -8823,6 +8978,27 @@
|
||||
],
|
||||
"title": "OnboardingStep"
|
||||
},
|
||||
"OperationCompleteRequest": {
|
||||
"properties": {
|
||||
"success": { "type": "boolean", "title": "Success" },
|
||||
"result": {
|
||||
"anyOf": [
|
||||
{ "additionalProperties": true, "type": "object" },
|
||||
{ "type": "string" },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Result"
|
||||
},
|
||||
"error": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Error"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["success"],
|
||||
"title": "OperationCompleteRequest",
|
||||
"description": "Request model for external completion webhook."
|
||||
},
|
||||
"Pagination": {
|
||||
"properties": {
|
||||
"total_items": {
|
||||
@@ -9678,6 +9854,12 @@
|
||||
"items": { "additionalProperties": true, "type": "object" },
|
||||
"type": "array",
|
||||
"title": "Messages"
|
||||
},
|
||||
"active_stream": {
|
||||
"anyOf": [
|
||||
{ "$ref": "#/components/schemas/ActiveStreamInfo" },
|
||||
{ "type": "null" }
|
||||
]
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId";
|
||||
import { useCopilotStore } from "@/app/(platform)/copilot/copilot-page-store";
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
@@ -25,8 +24,8 @@ export function Chat({
|
||||
}: ChatProps) {
|
||||
const { urlSessionId } = useCopilotSessionId();
|
||||
const hasHandledNotFoundRef = useRef(false);
|
||||
const isSwitchingSession = useCopilotStore((s) => s.isSwitchingSession);
|
||||
const {
|
||||
session,
|
||||
messages,
|
||||
isLoading,
|
||||
isCreating,
|
||||
@@ -38,6 +37,21 @@ export function Chat({
|
||||
startPollingForOperation,
|
||||
} = useChat({ urlSessionId });
|
||||
|
||||
// Extract active stream info for reconnection
|
||||
const activeStream = (session as { active_stream?: { task_id: string; last_message_id: string } })?.active_stream;
|
||||
|
||||
// Debug logging for SSE reconnection
|
||||
if (session) {
|
||||
console.info("[SSE-RECONNECT] Session loaded:", {
|
||||
sessionId,
|
||||
hasActiveStream: !!activeStream,
|
||||
activeStream: activeStream ? {
|
||||
taskId: activeStream.task_id,
|
||||
lastMessageId: activeStream.last_message_id,
|
||||
} : null,
|
||||
});
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
if (!onSessionNotFound) return;
|
||||
if (!urlSessionId) return;
|
||||
@@ -53,8 +67,7 @@ export function Chat({
|
||||
isCreating,
|
||||
]);
|
||||
|
||||
const shouldShowLoader =
|
||||
(showLoader && (isLoading || isCreating)) || isSwitchingSession;
|
||||
const shouldShowLoader = showLoader && (isLoading || isCreating);
|
||||
|
||||
return (
|
||||
<div className={cn("flex h-full flex-col", className)}>
|
||||
@@ -66,21 +79,19 @@ export function Chat({
|
||||
<div className="flex flex-col items-center gap-3">
|
||||
<LoadingSpinner size="large" className="text-neutral-400" />
|
||||
<Text variant="body" className="text-zinc-500">
|
||||
{isSwitchingSession
|
||||
? "Switching chat..."
|
||||
: "Loading your chat..."}
|
||||
Loading your chat...
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Error State */}
|
||||
{error && !isLoading && !isSwitchingSession && (
|
||||
{error && !isLoading && (
|
||||
<ChatErrorState error={error} onRetry={createSession} />
|
||||
)}
|
||||
|
||||
{/* Session Content */}
|
||||
{sessionId && !isLoading && !error && !isSwitchingSession && (
|
||||
{sessionId && !isLoading && !error && (
|
||||
<ChatContainer
|
||||
sessionId={sessionId}
|
||||
initialMessages={messages}
|
||||
@@ -88,6 +99,10 @@ export function Chat({
|
||||
className="flex-1"
|
||||
onStreamingChange={onStreamingChange}
|
||||
onOperationStarted={startPollingForOperation}
|
||||
activeStream={activeStream ? {
|
||||
taskId: activeStream.task_id,
|
||||
lastMessageId: activeStream.last_message_id,
|
||||
} : undefined}
|
||||
/>
|
||||
)}
|
||||
</main>
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
# SSE Reconnection Contract for Long-Running Operations
|
||||
|
||||
This document describes the client-side contract for handling SSE (Server-Sent Events) disconnections and reconnecting to long-running background tasks.
|
||||
|
||||
## Overview
|
||||
|
||||
When a user triggers a long-running operation (like agent generation), the backend:
|
||||
|
||||
1. Spawns a background task that survives SSE disconnections
|
||||
2. Returns an `operation_started` response with a `task_id`
|
||||
3. Stores stream messages in Redis Streams for replay
|
||||
|
||||
Clients can reconnect to the task stream at any time to receive missed messages.
|
||||
|
||||
## Client-Side Flow
|
||||
|
||||
### 1. Receiving Operation Started
|
||||
|
||||
When you receive an `operation_started` tool response:
|
||||
|
||||
```typescript
|
||||
// The response includes a task_id for reconnection
|
||||
{
|
||||
type: "operation_started",
|
||||
tool_name: "generate_agent",
|
||||
operation_id: "uuid-...",
|
||||
task_id: "task-uuid-...", // <-- Store this for reconnection
|
||||
message: "Operation started. You can close this tab."
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Storing Task Info
|
||||
|
||||
Use the chat store to track the active task:
|
||||
|
||||
```typescript
|
||||
import { useChatStore } from "./chat-store";
|
||||
|
||||
// When operation_started is received:
|
||||
useChatStore.getState().setActiveTask(sessionId, {
|
||||
taskId: response.task_id,
|
||||
operationId: response.operation_id,
|
||||
toolName: response.tool_name,
|
||||
lastMessageId: "0",
|
||||
});
|
||||
```
|
||||
|
||||
### 3. Reconnecting to a Task
|
||||
|
||||
To reconnect (e.g., after page refresh or tab reopen):
|
||||
|
||||
```typescript
|
||||
const { reconnectToTask, getActiveTask } = useChatStore.getState();
|
||||
|
||||
// Check if there's an active task for this session
|
||||
const activeTask = getActiveTask(sessionId);
|
||||
|
||||
if (activeTask) {
|
||||
// Reconnect to the task stream
|
||||
await reconnectToTask(
|
||||
sessionId,
|
||||
activeTask.taskId,
|
||||
activeTask.lastMessageId, // Resume from last position
|
||||
(chunk) => {
|
||||
// Handle incoming chunks
|
||||
console.log("Received chunk:", chunk);
|
||||
},
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Tracking Message Position
|
||||
|
||||
To enable precise replay, update the last message ID as chunks arrive:
|
||||
|
||||
```typescript
|
||||
const { updateTaskLastMessageId } = useChatStore.getState();
|
||||
|
||||
function handleChunk(chunk: StreamChunk) {
|
||||
// If chunk has an index/id, track it
|
||||
if (chunk.idx !== undefined) {
|
||||
updateTaskLastMessageId(sessionId, String(chunk.idx));
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### Task Stream Reconnection
|
||||
|
||||
```
|
||||
GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
|
||||
```
|
||||
|
||||
- `taskId`: The task ID from `operation_started`
|
||||
- `last_message_id`: Last received message index (default: "0" for full replay)
|
||||
|
||||
Returns: SSE stream of missed messages + live updates
|
||||
|
||||
## Chunk Types
|
||||
|
||||
The reconnected stream follows the same Vercel AI SDK protocol:
|
||||
|
||||
| Type | Description |
|
||||
| ----------------------- | ----------------------- |
|
||||
| `start` | Message lifecycle start |
|
||||
| `text-delta` | Streaming text content |
|
||||
| `text-end` | Text block completed |
|
||||
| `tool-output-available` | Tool result available |
|
||||
| `finish` | Stream completed |
|
||||
| `error` | Error occurred |
|
||||
|
||||
## Error Handling
|
||||
|
||||
If reconnection fails:
|
||||
|
||||
1. Check if task still exists (may have expired - default TTL: 1 hour)
|
||||
2. Fall back to polling the session for final state
|
||||
3. Show appropriate UI message to user
|
||||
|
||||
## Persistence Considerations
|
||||
|
||||
For robust reconnection across browser restarts:
|
||||
|
||||
```typescript
|
||||
// Store in localStorage/sessionStorage
|
||||
const ACTIVE_TASKS_KEY = "chat_active_tasks";
|
||||
|
||||
function persistActiveTask(sessionId: string, task: ActiveTaskInfo) {
|
||||
const tasks = JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
|
||||
tasks[sessionId] = task;
|
||||
localStorage.setItem(ACTIVE_TASKS_KEY, JSON.stringify(tasks));
|
||||
}
|
||||
|
||||
function loadPersistedTasks(): Record<string, ActiveTaskInfo> {
|
||||
return JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
|
||||
}
|
||||
```
|
||||
|
||||
## Backend Configuration
|
||||
|
||||
The following backend settings affect reconnection behavior:
|
||||
|
||||
| Setting | Default | Description |
|
||||
| ------------------- | ------- | ---------------------------------- |
|
||||
| `stream_ttl` | 3600s | How long streams are kept in Redis |
|
||||
| `stream_max_length` | 1000 | Max messages per stream |
|
||||
|
||||
## Testing
|
||||
|
||||
To test reconnection locally:
|
||||
|
||||
1. Start a long-running operation (e.g., agent generation)
|
||||
2. Note the `task_id` from the `operation_started` response
|
||||
3. Close the browser tab
|
||||
4. Reopen and call `reconnectToTask` with the saved `task_id`
|
||||
5. Verify that missed messages are replayed
|
||||
|
||||
See the main README for full local development setup.
|
||||
@@ -0,0 +1,16 @@
|
||||
/**
|
||||
* Constants for the chat system.
|
||||
*
|
||||
* Centralizes magic strings and values used across chat components.
|
||||
*/
|
||||
|
||||
// LocalStorage keys
|
||||
export const STORAGE_KEY_ACTIVE_TASKS = "chat_active_tasks";
|
||||
|
||||
// Redis Stream IDs
|
||||
export const INITIAL_MESSAGE_ID = "0";
|
||||
export const INITIAL_STREAM_ID = "0-0";
|
||||
|
||||
// TTL values (in milliseconds)
|
||||
export const COMPLETED_STREAM_TTL_MS = 5 * 60 * 1000; // 5 minutes
|
||||
export const ACTIVE_TASK_TTL_MS = 60 * 60 * 1000; // 1 hour
|
||||
@@ -1,6 +1,12 @@
|
||||
"use client";
|
||||
|
||||
import { create } from "zustand";
|
||||
import {
|
||||
ACTIVE_TASK_TTL_MS,
|
||||
COMPLETED_STREAM_TTL_MS,
|
||||
INITIAL_STREAM_ID,
|
||||
STORAGE_KEY_ACTIVE_TASKS,
|
||||
} from "./chat-constants";
|
||||
import type {
|
||||
ActiveStream,
|
||||
StreamChunk,
|
||||
@@ -8,15 +14,64 @@ import type {
|
||||
StreamResult,
|
||||
StreamStatus,
|
||||
} from "./chat-types";
|
||||
import { executeStream } from "./stream-executor";
|
||||
import { executeStream, executeTaskReconnect } from "./stream-executor";
|
||||
|
||||
const COMPLETED_STREAM_TTL = 5 * 60 * 1000; // 5 minutes
|
||||
/**
|
||||
* Tracks active task info for SSE reconnection.
|
||||
* When a long-running operation starts, we store this so clients can reconnect
|
||||
* if the browser tab is closed and reopened.
|
||||
*/
|
||||
export interface ActiveTaskInfo {
|
||||
taskId: string;
|
||||
sessionId: string;
|
||||
operationId: string;
|
||||
toolName: string;
|
||||
lastMessageId: string; // Last processed message ID for replay (Redis Stream format: "0-0")
|
||||
startedAt: number;
|
||||
}
|
||||
|
||||
/** Load active tasks from localStorage */
|
||||
function loadPersistedTasks(): Map<string, ActiveTaskInfo> {
|
||||
if (typeof window === "undefined") return new Map();
|
||||
try {
|
||||
const stored = localStorage.getItem(STORAGE_KEY_ACTIVE_TASKS);
|
||||
if (!stored) return new Map();
|
||||
const parsed = JSON.parse(stored) as Record<string, ActiveTaskInfo>;
|
||||
const now = Date.now();
|
||||
const tasks = new Map<string, ActiveTaskInfo>();
|
||||
// Filter out expired tasks
|
||||
for (const [sessionId, task] of Object.entries(parsed)) {
|
||||
if (now - task.startedAt < ACTIVE_TASK_TTL_MS) {
|
||||
tasks.set(sessionId, task);
|
||||
}
|
||||
}
|
||||
return tasks;
|
||||
} catch {
|
||||
return new Map();
|
||||
}
|
||||
}
|
||||
|
||||
/** Save active tasks to localStorage */
|
||||
function persistTasks(tasks: Map<string, ActiveTaskInfo>): void {
|
||||
if (typeof window === "undefined") return;
|
||||
try {
|
||||
const obj: Record<string, ActiveTaskInfo> = {};
|
||||
for (const [sessionId, task] of tasks) {
|
||||
obj[sessionId] = task;
|
||||
}
|
||||
localStorage.setItem(STORAGE_KEY_ACTIVE_TASKS, JSON.stringify(obj));
|
||||
} catch {
|
||||
// Ignore storage errors
|
||||
}
|
||||
}
|
||||
|
||||
interface ChatStoreState {
|
||||
activeStreams: Map<string, ActiveStream>;
|
||||
completedStreams: Map<string, StreamResult>;
|
||||
activeSessions: Set<string>;
|
||||
streamCompleteCallbacks: Set<StreamCompleteCallback>;
|
||||
/** Active tasks for SSE reconnection - keyed by sessionId */
|
||||
activeTasks: Map<string, ActiveTaskInfo>;
|
||||
}
|
||||
|
||||
interface ChatStoreActions {
|
||||
@@ -41,6 +96,24 @@ interface ChatStoreActions {
|
||||
unregisterActiveSession: (sessionId: string) => void;
|
||||
isSessionActive: (sessionId: string) => boolean;
|
||||
onStreamComplete: (callback: StreamCompleteCallback) => () => void;
|
||||
/** Track active task for SSE reconnection */
|
||||
setActiveTask: (
|
||||
sessionId: string,
|
||||
taskInfo: Omit<ActiveTaskInfo, "sessionId" | "startedAt">,
|
||||
) => void;
|
||||
/** Get active task for a session */
|
||||
getActiveTask: (sessionId: string) => ActiveTaskInfo | undefined;
|
||||
/** Clear active task when operation completes */
|
||||
clearActiveTask: (sessionId: string) => void;
|
||||
/** Reconnect to an existing task stream */
|
||||
reconnectToTask: (
|
||||
sessionId: string,
|
||||
taskId: string,
|
||||
lastMessageId?: string,
|
||||
onChunk?: (chunk: StreamChunk) => void,
|
||||
) => Promise<void>;
|
||||
/** Update last message ID for a task (for tracking replay position) */
|
||||
updateTaskLastMessageId: (sessionId: string, lastMessageId: string) => void;
|
||||
}
|
||||
|
||||
type ChatStore = ChatStoreState & ChatStoreActions;
|
||||
@@ -64,18 +137,79 @@ function cleanupExpiredStreams(
|
||||
const now = Date.now();
|
||||
const cleaned = new Map(completedStreams);
|
||||
for (const [sessionId, result] of cleaned) {
|
||||
if (now - result.completedAt > COMPLETED_STREAM_TTL) {
|
||||
if (now - result.completedAt > COMPLETED_STREAM_TTL_MS) {
|
||||
cleaned.delete(sessionId);
|
||||
}
|
||||
}
|
||||
return cleaned;
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean up an existing stream for a session and move it to completed streams.
|
||||
* Returns updated maps for both active and completed streams.
|
||||
*/
|
||||
function cleanupExistingStream(
|
||||
sessionId: string,
|
||||
activeStreams: Map<string, ActiveStream>,
|
||||
completedStreams: Map<string, StreamResult>,
|
||||
callbacks: Set<StreamCompleteCallback>,
|
||||
): {
|
||||
activeStreams: Map<string, ActiveStream>;
|
||||
completedStreams: Map<string, StreamResult>;
|
||||
} {
|
||||
const newActiveStreams = new Map(activeStreams);
|
||||
let newCompletedStreams = new Map(completedStreams);
|
||||
|
||||
const existingStream = newActiveStreams.get(sessionId);
|
||||
if (existingStream) {
|
||||
existingStream.abortController.abort();
|
||||
const normalizedStatus =
|
||||
existingStream.status === "streaming" ? "completed" : existingStream.status;
|
||||
const result: StreamResult = {
|
||||
sessionId,
|
||||
status: normalizedStatus,
|
||||
chunks: existingStream.chunks,
|
||||
completedAt: Date.now(),
|
||||
error: existingStream.error,
|
||||
};
|
||||
newCompletedStreams.set(sessionId, result);
|
||||
newActiveStreams.delete(sessionId);
|
||||
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
|
||||
if (normalizedStatus === "completed" || normalizedStatus === "error") {
|
||||
notifyStreamComplete(callbacks, sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
return { activeStreams: newActiveStreams, completedStreams: newCompletedStreams };
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new active stream with initial state.
|
||||
*/
|
||||
function createActiveStream(
|
||||
sessionId: string,
|
||||
onChunk?: (chunk: StreamChunk) => void,
|
||||
): ActiveStream {
|
||||
const abortController = new AbortController();
|
||||
const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
|
||||
if (onChunk) initialCallbacks.add(onChunk);
|
||||
|
||||
return {
|
||||
sessionId,
|
||||
abortController,
|
||||
status: "streaming",
|
||||
startedAt: Date.now(),
|
||||
chunks: [],
|
||||
onChunkCallbacks: initialCallbacks,
|
||||
};
|
||||
}
|
||||
|
||||
export const useChatStore = create<ChatStore>((set, get) => ({
|
||||
activeStreams: new Map(),
|
||||
completedStreams: new Map(),
|
||||
activeSessions: new Set(),
|
||||
streamCompleteCallbacks: new Set(),
|
||||
activeTasks: loadPersistedTasks(),
|
||||
|
||||
startStream: async function startStream(
|
||||
sessionId,
|
||||
@@ -85,45 +219,19 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
||||
onChunk,
|
||||
) {
|
||||
const state = get();
|
||||
const newActiveStreams = new Map(state.activeStreams);
|
||||
let newCompletedStreams = new Map(state.completedStreams);
|
||||
const callbacks = state.streamCompleteCallbacks;
|
||||
|
||||
const existingStream = newActiveStreams.get(sessionId);
|
||||
if (existingStream) {
|
||||
existingStream.abortController.abort();
|
||||
const normalizedStatus =
|
||||
existingStream.status === "streaming"
|
||||
? "completed"
|
||||
: existingStream.status;
|
||||
const result: StreamResult = {
|
||||
// Clean up any existing stream for this session
|
||||
const { activeStreams: newActiveStreams, completedStreams: newCompletedStreams } =
|
||||
cleanupExistingStream(
|
||||
sessionId,
|
||||
status: normalizedStatus,
|
||||
chunks: existingStream.chunks,
|
||||
completedAt: Date.now(),
|
||||
error: existingStream.error,
|
||||
};
|
||||
newCompletedStreams.set(sessionId, result);
|
||||
newActiveStreams.delete(sessionId);
|
||||
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
|
||||
if (normalizedStatus === "completed" || normalizedStatus === "error") {
|
||||
notifyStreamComplete(callbacks, sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
const abortController = new AbortController();
|
||||
const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
|
||||
if (onChunk) initialCallbacks.add(onChunk);
|
||||
|
||||
const stream: ActiveStream = {
|
||||
sessionId,
|
||||
abortController,
|
||||
status: "streaming",
|
||||
startedAt: Date.now(),
|
||||
chunks: [],
|
||||
onChunkCallbacks: initialCallbacks,
|
||||
};
|
||||
state.activeStreams,
|
||||
state.completedStreams,
|
||||
callbacks,
|
||||
);
|
||||
|
||||
// Create new stream
|
||||
const stream = createActiveStream(sessionId, onChunk);
|
||||
newActiveStreams.set(sessionId, stream);
|
||||
set({
|
||||
activeStreams: newActiveStreams,
|
||||
@@ -286,4 +394,134 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
||||
set({ streamCompleteCallbacks: cleanedCallbacks });
|
||||
};
|
||||
},
|
||||
|
||||
setActiveTask: function setActiveTask(sessionId, taskInfo) {
|
||||
const state = get();
|
||||
const newActiveTasks = new Map(state.activeTasks);
|
||||
newActiveTasks.set(sessionId, {
|
||||
...taskInfo,
|
||||
sessionId,
|
||||
startedAt: Date.now(),
|
||||
});
|
||||
set({ activeTasks: newActiveTasks });
|
||||
persistTasks(newActiveTasks);
|
||||
},
|
||||
|
||||
getActiveTask: function getActiveTask(sessionId) {
|
||||
return get().activeTasks.get(sessionId);
|
||||
},
|
||||
|
||||
clearActiveTask: function clearActiveTask(sessionId) {
|
||||
const state = get();
|
||||
if (!state.activeTasks.has(sessionId)) return;
|
||||
|
||||
const newActiveTasks = new Map(state.activeTasks);
|
||||
newActiveTasks.delete(sessionId);
|
||||
set({ activeTasks: newActiveTasks });
|
||||
persistTasks(newActiveTasks);
|
||||
},
|
||||
|
||||
reconnectToTask: async function reconnectToTask(
|
||||
sessionId,
|
||||
taskId,
|
||||
lastMessageId = INITIAL_STREAM_ID,
|
||||
onChunk,
|
||||
) {
|
||||
console.info("[SSE-RECONNECT] reconnectToTask called:", {
|
||||
sessionId,
|
||||
taskId,
|
||||
lastMessageId,
|
||||
});
|
||||
|
||||
const state = get();
|
||||
const callbacks = state.streamCompleteCallbacks;
|
||||
|
||||
// Clean up any existing stream for this session
|
||||
const { activeStreams: newActiveStreams, completedStreams: newCompletedStreams } =
|
||||
cleanupExistingStream(
|
||||
sessionId,
|
||||
state.activeStreams,
|
||||
state.completedStreams,
|
||||
callbacks,
|
||||
);
|
||||
|
||||
// Create new stream for reconnection
|
||||
const stream = createActiveStream(sessionId, onChunk);
|
||||
newActiveStreams.set(sessionId, stream);
|
||||
set({
|
||||
activeStreams: newActiveStreams,
|
||||
completedStreams: newCompletedStreams,
|
||||
});
|
||||
|
||||
console.info("[SSE-RECONNECT] Starting executeTaskReconnect...");
|
||||
try {
|
||||
await executeTaskReconnect(stream, taskId, lastMessageId);
|
||||
console.info("[SSE-RECONNECT] executeTaskReconnect completed:", {
|
||||
sessionId,
|
||||
taskId,
|
||||
streamStatus: stream.status,
|
||||
});
|
||||
} finally {
|
||||
if (onChunk) stream.onChunkCallbacks.delete(onChunk);
|
||||
if (stream.status !== "streaming") {
|
||||
const currentState = get();
|
||||
const finalActiveStreams = new Map(currentState.activeStreams);
|
||||
let finalCompletedStreams = new Map(currentState.completedStreams);
|
||||
|
||||
const storedStream = finalActiveStreams.get(sessionId);
|
||||
if (storedStream === stream) {
|
||||
const result: StreamResult = {
|
||||
sessionId,
|
||||
status: stream.status,
|
||||
chunks: stream.chunks,
|
||||
completedAt: Date.now(),
|
||||
error: stream.error,
|
||||
};
|
||||
finalCompletedStreams.set(sessionId, result);
|
||||
finalActiveStreams.delete(sessionId);
|
||||
finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
|
||||
set({
|
||||
activeStreams: finalActiveStreams,
|
||||
completedStreams: finalCompletedStreams,
|
||||
});
|
||||
if (stream.status === "completed" || stream.status === "error") {
|
||||
notifyStreamComplete(
|
||||
currentState.streamCompleteCallbacks,
|
||||
sessionId,
|
||||
);
|
||||
// Clear active task on completion
|
||||
const taskState = get();
|
||||
const newActiveTasks = new Map(taskState.activeTasks);
|
||||
const hadActiveTask = newActiveTasks.has(sessionId);
|
||||
newActiveTasks.delete(sessionId);
|
||||
set({ activeTasks: newActiveTasks });
|
||||
persistTasks(newActiveTasks);
|
||||
if (hadActiveTask) {
|
||||
console.info(
|
||||
`[ChatStore] Cleared active task for session ${sessionId} ` +
|
||||
`(stream status: ${stream.status})`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
updateTaskLastMessageId: function updateTaskLastMessageId(
|
||||
sessionId,
|
||||
lastMessageId,
|
||||
) {
|
||||
const state = get();
|
||||
const task = state.activeTasks.get(sessionId);
|
||||
if (!task) return;
|
||||
|
||||
const newActiveTasks = new Map(state.activeTasks);
|
||||
newActiveTasks.set(sessionId, {
|
||||
...task,
|
||||
lastMessageId,
|
||||
});
|
||||
set({ activeTasks: newActiveTasks });
|
||||
persistTasks(newActiveTasks);
|
||||
},
|
||||
}));
|
||||
|
||||
@@ -4,6 +4,7 @@ export type StreamStatus = "idle" | "streaming" | "completed" | "error";
|
||||
|
||||
export interface StreamChunk {
|
||||
type:
|
||||
| "stream_start"
|
||||
| "text_chunk"
|
||||
| "text_ended"
|
||||
| "tool_call"
|
||||
@@ -15,6 +16,7 @@ export interface StreamChunk {
|
||||
| "error"
|
||||
| "usage"
|
||||
| "stream_end";
|
||||
taskId?: string; // Task ID for SSE reconnection
|
||||
timestamp?: string;
|
||||
content?: string;
|
||||
message?: string;
|
||||
@@ -41,7 +43,7 @@ export interface StreamChunk {
|
||||
}
|
||||
|
||||
export type VercelStreamChunk =
|
||||
| { type: "start"; messageId: string }
|
||||
| { type: "start"; messageId: string; taskId?: string }
|
||||
| { type: "finish" }
|
||||
| { type: "text-start"; id: string }
|
||||
| { type: "text-delta"; id: string; delta: string }
|
||||
@@ -92,3 +94,67 @@ export interface StreamResult {
|
||||
}
|
||||
|
||||
export type StreamCompleteCallback = (sessionId: string) => void;
|
||||
|
||||
// Type guards for message types
|
||||
|
||||
/**
|
||||
* Check if a message has a toolId property.
|
||||
*/
|
||||
export function hasToolId<T extends { type: string }>(
|
||||
msg: T,
|
||||
): msg is T & { toolId: string } {
|
||||
return "toolId" in msg && typeof (msg as Record<string, unknown>).toolId === "string";
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a message has an operationId property.
|
||||
*/
|
||||
export function hasOperationId<T extends { type: string }>(
|
||||
msg: T,
|
||||
): msg is T & { operationId: string } {
|
||||
return (
|
||||
"operationId" in msg &&
|
||||
typeof (msg as Record<string, unknown>).operationId === "string"
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a message has a toolCallId property.
|
||||
*/
|
||||
export function hasToolCallId<T extends { type: string }>(
|
||||
msg: T,
|
||||
): msg is T & { toolCallId: string } {
|
||||
return (
|
||||
"toolCallId" in msg &&
|
||||
typeof (msg as Record<string, unknown>).toolCallId === "string"
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a message is an operation message type.
|
||||
*/
|
||||
export function isOperationMessage<T extends { type: string }>(
|
||||
msg: T,
|
||||
): msg is T & {
|
||||
type: "operation_started" | "operation_pending" | "operation_in_progress";
|
||||
} {
|
||||
return (
|
||||
msg.type === "operation_started" ||
|
||||
msg.type === "operation_pending" ||
|
||||
msg.type === "operation_in_progress"
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the tool ID from a message if available.
|
||||
* Checks toolId, operationId, and toolCallId properties.
|
||||
*/
|
||||
export function getToolIdFromMessage<T extends { type: string }>(
|
||||
msg: T,
|
||||
): string | undefined {
|
||||
const record = msg as Record<string, unknown>;
|
||||
if (typeof record.toolId === "string") return record.toolId;
|
||||
if (typeof record.operationId === "string") return record.operationId;
|
||||
if (typeof record.toolCallId === "string") return record.toolCallId;
|
||||
return undefined;
|
||||
}
|
||||
|
||||
@@ -17,6 +17,11 @@ export interface ChatContainerProps {
|
||||
className?: string;
|
||||
onStreamingChange?: (isStreaming: boolean) => void;
|
||||
onOperationStarted?: () => void;
|
||||
/** Active stream info from the server for reconnection */
|
||||
activeStream?: {
|
||||
taskId: string;
|
||||
lastMessageId: string;
|
||||
};
|
||||
}
|
||||
|
||||
export function ChatContainer({
|
||||
@@ -26,6 +31,7 @@ export function ChatContainer({
|
||||
className,
|
||||
onStreamingChange,
|
||||
onOperationStarted,
|
||||
activeStream,
|
||||
}: ChatContainerProps) {
|
||||
const {
|
||||
messages,
|
||||
@@ -41,6 +47,7 @@ export function ChatContainer({
|
||||
initialMessages,
|
||||
initialPrompt,
|
||||
onOperationStarted,
|
||||
activeStream,
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
|
||||
@@ -34,6 +34,22 @@ export function createStreamEventDispatcher(
|
||||
}
|
||||
|
||||
switch (chunk.type) {
|
||||
case "stream_start":
|
||||
// Store task ID for SSE reconnection
|
||||
if (chunk.taskId && deps.onActiveTaskStarted) {
|
||||
console.info("[ChatStream] Stream started with task ID:", {
|
||||
sessionId: deps.sessionId,
|
||||
taskId: chunk.taskId,
|
||||
});
|
||||
deps.onActiveTaskStarted({
|
||||
taskId: chunk.taskId,
|
||||
operationId: chunk.taskId, // Use taskId as operationId for chat streams
|
||||
toolName: "chat",
|
||||
toolCallId: "chat_stream",
|
||||
});
|
||||
}
|
||||
break;
|
||||
|
||||
case "text_chunk":
|
||||
handleTextChunk(chunk, deps);
|
||||
break;
|
||||
@@ -56,7 +72,8 @@ export function createStreamEventDispatcher(
|
||||
break;
|
||||
|
||||
case "stream_end":
|
||||
console.info("[ChatStream] Stream ended:", {
|
||||
// Note: "finish" type from backend gets normalized to "stream_end" by normalizeStreamChunk
|
||||
console.info("[SSE-RECONNECT] Stream ended:", {
|
||||
sessionId: deps.sessionId,
|
||||
hasResponse: deps.hasResponseRef.current,
|
||||
chunkCount: deps.streamingChunksRef.current.length,
|
||||
|
||||
@@ -23,6 +23,12 @@ export interface HandlerDependencies {
|
||||
setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
|
||||
sessionId: string;
|
||||
onOperationStarted?: () => void;
|
||||
onActiveTaskStarted?: (taskInfo: {
|
||||
taskId: string;
|
||||
operationId: string;
|
||||
toolName: string;
|
||||
toolCallId: string;
|
||||
}) => void;
|
||||
}
|
||||
|
||||
export function isRegionBlockedError(chunk: StreamChunk): boolean {
|
||||
@@ -164,9 +170,19 @@ export function handleToolResponse(
|
||||
}
|
||||
return;
|
||||
}
|
||||
// Trigger polling when operation_started is received
|
||||
// Trigger polling and store task info when operation_started is received
|
||||
if (responseMessage.type === "operation_started") {
|
||||
deps.onOperationStarted?.();
|
||||
// Store task info for SSE reconnection if taskId is present
|
||||
const taskId = (responseMessage as any).taskId;
|
||||
if (taskId && deps.onActiveTaskStarted) {
|
||||
deps.onActiveTaskStarted({
|
||||
taskId,
|
||||
operationId: (responseMessage as any).operationId || "",
|
||||
toolName: (responseMessage as any).toolName || "",
|
||||
toolCallId: (responseMessage as any).toolId || "",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
deps.setMessages((prev) => {
|
||||
@@ -205,8 +221,10 @@ export function handleStreamEnd(
|
||||
_chunk: StreamChunk,
|
||||
deps: HandlerDependencies,
|
||||
) {
|
||||
console.info("[SSE-RECONNECT] handleStreamEnd called, resetting streaming state");
|
||||
const completedContent = deps.streamingChunksRef.current.join("");
|
||||
if (!completedContent.trim() && !deps.hasResponseRef.current) {
|
||||
console.info("[SSE-RECONNECT] No content received, adding placeholder message");
|
||||
deps.setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
@@ -245,10 +263,14 @@ export function handleStreamEnd(
|
||||
|
||||
export function handleError(chunk: StreamChunk, deps: HandlerDependencies) {
|
||||
const errorMessage = chunk.message || chunk.content || "An error occurred";
|
||||
console.error("Stream error:", errorMessage);
|
||||
console.error("[ChatStream] Stream error:", errorMessage, {
|
||||
sessionId: deps.sessionId,
|
||||
chunk,
|
||||
});
|
||||
if (isRegionBlockedError(chunk)) {
|
||||
deps.setIsRegionBlockedModalOpen(true);
|
||||
}
|
||||
console.info("[ChatStream] Resetting streaming state due to error");
|
||||
deps.setIsStreamingInitiated(false);
|
||||
deps.setHasTextChunks(false);
|
||||
deps.setStreamingChunks([]);
|
||||
|
||||
@@ -349,6 +349,7 @@ export function parseToolResponse(
|
||||
toolName: (parsedResult.tool_name as string) || toolName,
|
||||
toolId,
|
||||
operationId: (parsedResult.operation_id as string) || "",
|
||||
taskId: (parsedResult.task_id as string) || undefined, // For SSE reconnection
|
||||
message:
|
||||
(parsedResult.message as string) ||
|
||||
"Operation started. You can close this tab.",
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
||||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
import { INITIAL_STREAM_ID } from "../../chat-constants";
|
||||
import { useChatStore } from "../../chat-store";
|
||||
import { toast } from "sonner";
|
||||
import { useChatStream } from "../../useChatStream";
|
||||
import { usePageContext } from "../../usePageContext";
|
||||
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
|
||||
import {
|
||||
getToolIdFromMessage,
|
||||
hasToolId,
|
||||
isOperationMessage,
|
||||
} from "../../chat-types";
|
||||
import { createStreamEventDispatcher } from "./createStreamEventDispatcher";
|
||||
import {
|
||||
createUserMessage,
|
||||
@@ -14,6 +20,46 @@ import {
|
||||
processInitialMessages,
|
||||
} from "./helpers";
|
||||
|
||||
/**
|
||||
* Dependencies for creating a stream event dispatcher.
|
||||
* Extracted to allow helper function creation.
|
||||
*/
|
||||
interface DispatcherDeps {
|
||||
setHasTextChunks: React.Dispatch<React.SetStateAction<boolean>>;
|
||||
setStreamingChunks: React.Dispatch<React.SetStateAction<string[]>>;
|
||||
streamingChunksRef: React.MutableRefObject<string[]>;
|
||||
hasResponseRef: React.MutableRefObject<boolean>;
|
||||
setMessages: React.Dispatch<React.SetStateAction<ChatMessageData[]>>;
|
||||
setIsRegionBlockedModalOpen: React.Dispatch<React.SetStateAction<boolean>>;
|
||||
sessionId: string;
|
||||
setIsStreamingInitiated: React.Dispatch<React.SetStateAction<boolean>>;
|
||||
onOperationStarted?: () => void;
|
||||
onActiveTaskStarted: (taskInfo: {
|
||||
taskId: string;
|
||||
operationId: string;
|
||||
toolName: string;
|
||||
toolCallId: string;
|
||||
}) => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a stream event dispatcher with the given dependencies.
|
||||
*/
|
||||
function createDispatcher(deps: DispatcherDeps) {
|
||||
return createStreamEventDispatcher({
|
||||
setHasTextChunks: deps.setHasTextChunks,
|
||||
setStreamingChunks: deps.setStreamingChunks,
|
||||
streamingChunksRef: deps.streamingChunksRef,
|
||||
hasResponseRef: deps.hasResponseRef,
|
||||
setMessages: deps.setMessages,
|
||||
setIsRegionBlockedModalOpen: deps.setIsRegionBlockedModalOpen,
|
||||
sessionId: deps.sessionId,
|
||||
setIsStreamingInitiated: deps.setIsStreamingInitiated,
|
||||
onOperationStarted: deps.onOperationStarted,
|
||||
onActiveTaskStarted: deps.onActiveTaskStarted,
|
||||
});
|
||||
}
|
||||
|
||||
// Helper to generate deduplication key for a message
|
||||
function getMessageKey(msg: ChatMessageData): string {
|
||||
if (msg.type === "message") {
|
||||
@@ -24,13 +70,11 @@ function getMessageKey(msg: ChatMessageData): string {
|
||||
} else if (msg.type === "tool_call") {
|
||||
return `toolcall:${msg.toolId}`;
|
||||
} else if (msg.type === "tool_response") {
|
||||
return `toolresponse:${(msg as any).toolId}`;
|
||||
} else if (
|
||||
msg.type === "operation_started" ||
|
||||
msg.type === "operation_pending" ||
|
||||
msg.type === "operation_in_progress"
|
||||
) {
|
||||
return `op:${(msg as any).toolId || (msg as any).operationId || (msg as any).toolCallId || ""}:${msg.toolName}`;
|
||||
const toolId = hasToolId(msg) ? msg.toolId : "";
|
||||
return `toolresponse:${toolId}`;
|
||||
} else if (isOperationMessage(msg)) {
|
||||
const toolId = getToolIdFromMessage(msg) || "";
|
||||
return `op:${toolId}:${msg.toolName}`;
|
||||
} else {
|
||||
return `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`;
|
||||
}
|
||||
@@ -41,6 +85,11 @@ interface Args {
|
||||
initialMessages: SessionDetailResponse["messages"];
|
||||
initialPrompt?: string;
|
||||
onOperationStarted?: () => void;
|
||||
/** Active stream info from the server for reconnection */
|
||||
activeStream?: {
|
||||
taskId: string;
|
||||
lastMessageId: string;
|
||||
};
|
||||
}
|
||||
|
||||
export function useChatContainer({
|
||||
@@ -48,6 +97,7 @@ export function useChatContainer({
|
||||
initialMessages,
|
||||
initialPrompt,
|
||||
onOperationStarted,
|
||||
activeStream,
|
||||
}: Args) {
|
||||
const [messages, setMessages] = useState<ChatMessageData[]>([]);
|
||||
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
|
||||
@@ -65,30 +115,195 @@ export function useChatContainer({
|
||||
} = useChatStream();
|
||||
const activeStreams = useChatStore((s) => s.activeStreams);
|
||||
const subscribeToStream = useChatStore((s) => s.subscribeToStream);
|
||||
const setActiveTask = useChatStore((s) => s.setActiveTask);
|
||||
const getActiveTask = useChatStore((s) => s.getActiveTask);
|
||||
const reconnectToTask = useChatStore((s) => s.reconnectToTask);
|
||||
const isStreaming = isStreamingInitiated || hasTextChunks;
|
||||
// Track whether we've already connected to this activeStream to avoid duplicate connections
|
||||
const connectedActiveStreamRef = useRef<string | null>(null);
|
||||
|
||||
// Callback to store active task info for SSE reconnection
|
||||
function handleActiveTaskStarted(taskInfo: {
|
||||
taskId: string;
|
||||
operationId: string;
|
||||
toolName: string;
|
||||
toolCallId: string;
|
||||
}) {
|
||||
if (!sessionId) return;
|
||||
setActiveTask(sessionId, {
|
||||
taskId: taskInfo.taskId,
|
||||
operationId: taskInfo.operationId,
|
||||
toolName: taskInfo.toolName,
|
||||
lastMessageId: INITIAL_STREAM_ID,
|
||||
});
|
||||
}
|
||||
|
||||
useEffect(
|
||||
function handleSessionChange() {
|
||||
if (sessionId === previousSessionIdRef.current) return;
|
||||
const isSessionChange = sessionId !== previousSessionIdRef.current;
|
||||
|
||||
const prevSession = previousSessionIdRef.current;
|
||||
if (prevSession) {
|
||||
stopStreaming(prevSession);
|
||||
console.info("[SSE-RECONNECT] handleSessionChange effect running:", {
|
||||
sessionId,
|
||||
previousSessionId: previousSessionIdRef.current,
|
||||
isSessionChange,
|
||||
hasActiveStream: !!activeStream,
|
||||
activeStreamTaskId: activeStream?.taskId,
|
||||
connectedActiveStream: connectedActiveStreamRef.current,
|
||||
});
|
||||
|
||||
// Handle session change - reset state
|
||||
if (isSessionChange) {
|
||||
console.info("[SSE-RECONNECT] Session changed, resetting state");
|
||||
const prevSession = previousSessionIdRef.current;
|
||||
if (prevSession) {
|
||||
stopStreaming(prevSession);
|
||||
}
|
||||
previousSessionIdRef.current = sessionId;
|
||||
connectedActiveStreamRef.current = null; // Reset connected stream tracker
|
||||
setMessages([]);
|
||||
setStreamingChunks([]);
|
||||
streamingChunksRef.current = [];
|
||||
setHasTextChunks(false);
|
||||
setIsStreamingInitiated(false);
|
||||
hasResponseRef.current = false;
|
||||
}
|
||||
previousSessionIdRef.current = sessionId;
|
||||
setMessages([]);
|
||||
setStreamingChunks([]);
|
||||
streamingChunksRef.current = [];
|
||||
setHasTextChunks(false);
|
||||
setIsStreamingInitiated(false);
|
||||
hasResponseRef.current = false;
|
||||
|
||||
if (!sessionId) return;
|
||||
if (!sessionId) {
|
||||
console.info("[SSE-RECONNECT] No sessionId, skipping reconnection check");
|
||||
return;
|
||||
}
|
||||
|
||||
const activeStream = activeStreams.get(sessionId);
|
||||
if (!activeStream || activeStream.status !== "streaming") return;
|
||||
// Priority 1: Check if server told us there's an active stream (most authoritative)
|
||||
// Also handles the case where activeStream arrives after initial session load
|
||||
if (activeStream) {
|
||||
// Skip if we've already connected to this exact stream
|
||||
// Check and set immediately to prevent race conditions from effect re-runs
|
||||
const streamKey = `${sessionId}:${activeStream.taskId}`;
|
||||
if (connectedActiveStreamRef.current === streamKey) {
|
||||
console.info(
|
||||
"[SSE-RECONNECT] Already connected to this stream, skipping:",
|
||||
{ streamKey },
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const dispatcher = createStreamEventDispatcher({
|
||||
// Also skip if there's already an active stream for this session in the store
|
||||
// (handles case where effect re-runs due to activeStreams state change)
|
||||
const existingStream = activeStreams.get(sessionId);
|
||||
if (existingStream && existingStream.status === "streaming") {
|
||||
console.info(
|
||||
"[SSE-RECONNECT] Active stream already exists in store, skipping:",
|
||||
{ sessionId, status: existingStream.status },
|
||||
);
|
||||
connectedActiveStreamRef.current = streamKey;
|
||||
return;
|
||||
}
|
||||
|
||||
// Set immediately after check to prevent race conditions
|
||||
connectedActiveStreamRef.current = streamKey;
|
||||
|
||||
console.info(
|
||||
"[SSE-RECONNECT] Server reports active stream, initiating reconnection:",
|
||||
{
|
||||
sessionId,
|
||||
taskId: activeStream.taskId,
|
||||
lastMessageId: activeStream.lastMessageId,
|
||||
streamKey,
|
||||
},
|
||||
);
|
||||
|
||||
const dispatcher = createDispatcher({
|
||||
setHasTextChunks,
|
||||
setStreamingChunks,
|
||||
streamingChunksRef,
|
||||
hasResponseRef,
|
||||
setMessages,
|
||||
setIsRegionBlockedModalOpen,
|
||||
sessionId,
|
||||
setIsStreamingInitiated,
|
||||
onOperationStarted,
|
||||
onActiveTaskStarted: handleActiveTaskStarted,
|
||||
});
|
||||
|
||||
setIsStreamingInitiated(true);
|
||||
// Store this as the active task for future reconnects
|
||||
setActiveTask(sessionId, {
|
||||
taskId: activeStream.taskId,
|
||||
operationId: activeStream.taskId,
|
||||
toolName: "chat",
|
||||
lastMessageId: activeStream.lastMessageId,
|
||||
});
|
||||
// Reconnect to the task stream
|
||||
console.info("[SSE-RECONNECT] Calling reconnectToTask...");
|
||||
reconnectToTask(
|
||||
sessionId,
|
||||
activeStream.taskId,
|
||||
activeStream.lastMessageId,
|
||||
dispatcher,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Only check localStorage/in-memory on session change, not on every render
|
||||
if (!isSessionChange) {
|
||||
console.info(
|
||||
"[SSE-RECONNECT] No active stream and not a session change, skipping fallbacks",
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Priority 2: Check localStorage for active task (client-side state)
|
||||
console.info("[SSE-RECONNECT] Checking localStorage for active task...");
|
||||
const activeTask = getActiveTask(sessionId);
|
||||
if (activeTask) {
|
||||
console.info(
|
||||
"[SSE-RECONNECT] Found active task in localStorage, attempting reconnect:",
|
||||
{
|
||||
sessionId,
|
||||
taskId: activeTask.taskId,
|
||||
lastMessageId: activeTask.lastMessageId,
|
||||
},
|
||||
);
|
||||
|
||||
const dispatcher = createDispatcher({
|
||||
setHasTextChunks,
|
||||
setStreamingChunks,
|
||||
streamingChunksRef,
|
||||
hasResponseRef,
|
||||
setMessages,
|
||||
setIsRegionBlockedModalOpen,
|
||||
sessionId,
|
||||
setIsStreamingInitiated,
|
||||
onOperationStarted,
|
||||
onActiveTaskStarted: handleActiveTaskStarted,
|
||||
});
|
||||
|
||||
setIsStreamingInitiated(true);
|
||||
// Reconnect to the task stream
|
||||
console.info("[SSE-RECONNECT] Calling reconnectToTask from localStorage...");
|
||||
reconnectToTask(
|
||||
sessionId,
|
||||
activeTask.taskId,
|
||||
activeTask.lastMessageId,
|
||||
dispatcher,
|
||||
);
|
||||
return;
|
||||
} else {
|
||||
console.info("[SSE-RECONNECT] No active task in localStorage");
|
||||
}
|
||||
|
||||
// Priority 3: Check for an in-memory active stream (same-tab scenario)
|
||||
console.info("[SSE-RECONNECT] Checking in-memory active streams...");
|
||||
const inMemoryStream = activeStreams.get(sessionId);
|
||||
if (!inMemoryStream || inMemoryStream.status !== "streaming") {
|
||||
console.info("[SSE-RECONNECT] No in-memory active stream found:", {
|
||||
hasStream: !!inMemoryStream,
|
||||
status: inMemoryStream?.status,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const dispatcher = createDispatcher({
|
||||
setHasTextChunks,
|
||||
setStreamingChunks,
|
||||
streamingChunksRef,
|
||||
@@ -98,6 +313,7 @@ export function useChatContainer({
|
||||
sessionId,
|
||||
setIsStreamingInitiated,
|
||||
onOperationStarted,
|
||||
onActiveTaskStarted: handleActiveTaskStarted,
|
||||
});
|
||||
|
||||
setIsStreamingInitiated(true);
|
||||
@@ -110,6 +326,10 @@ export function useChatContainer({
|
||||
activeStreams,
|
||||
subscribeToStream,
|
||||
onOperationStarted,
|
||||
getActiveTask,
|
||||
reconnectToTask,
|
||||
activeStream,
|
||||
setActiveTask,
|
||||
],
|
||||
);
|
||||
|
||||
@@ -124,7 +344,7 @@ export function useChatContainer({
|
||||
msg.type === "agent_carousel" ||
|
||||
msg.type === "execution_started"
|
||||
) {
|
||||
const toolId = (msg as any).toolId;
|
||||
const toolId = hasToolId(msg) ? msg.toolId : undefined;
|
||||
if (toolId) {
|
||||
ids.add(toolId);
|
||||
}
|
||||
@@ -141,12 +361,8 @@ export function useChatContainer({
|
||||
|
||||
setMessages((prev) => {
|
||||
const filtered = prev.filter((msg) => {
|
||||
if (
|
||||
msg.type === "operation_started" ||
|
||||
msg.type === "operation_pending" ||
|
||||
msg.type === "operation_in_progress"
|
||||
) {
|
||||
const toolId = (msg as any).toolId || (msg as any).toolCallId;
|
||||
if (isOperationMessage(msg)) {
|
||||
const toolId = getToolIdFromMessage(msg);
|
||||
if (toolId && completedToolIds.has(toolId)) {
|
||||
return false; // Remove - operation completed
|
||||
}
|
||||
@@ -174,12 +390,8 @@ export function useChatContainer({
|
||||
// Filter local messages: remove duplicates and completed operation messages
|
||||
const newLocalMessages = messages.filter((msg) => {
|
||||
// Remove operation messages for completed tools
|
||||
if (
|
||||
msg.type === "operation_started" ||
|
||||
msg.type === "operation_pending" ||
|
||||
msg.type === "operation_in_progress"
|
||||
) {
|
||||
const toolId = (msg as any).toolId || (msg as any).toolCallId;
|
||||
if (isOperationMessage(msg)) {
|
||||
const toolId = getToolIdFromMessage(msg);
|
||||
if (toolId && completedToolIds.has(toolId)) {
|
||||
return false;
|
||||
}
|
||||
@@ -215,7 +427,7 @@ export function useChatContainer({
|
||||
setIsStreamingInitiated(true);
|
||||
hasResponseRef.current = false;
|
||||
|
||||
const dispatcher = createStreamEventDispatcher({
|
||||
const dispatcher = createDispatcher({
|
||||
setHasTextChunks,
|
||||
setStreamingChunks,
|
||||
streamingChunksRef,
|
||||
@@ -225,6 +437,7 @@ export function useChatContainer({
|
||||
sessionId,
|
||||
setIsStreamingInitiated,
|
||||
onOperationStarted,
|
||||
onActiveTaskStarted: handleActiveTaskStarted,
|
||||
});
|
||||
|
||||
try {
|
||||
|
||||
@@ -111,6 +111,7 @@ export type ChatMessageData =
|
||||
toolName: string;
|
||||
toolId: string;
|
||||
operationId: string;
|
||||
taskId?: string; // For SSE reconnection
|
||||
message: string;
|
||||
timestamp?: string | Date;
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { INITIAL_MESSAGE_ID } from "./chat-constants";
|
||||
import type {
|
||||
ActiveStream,
|
||||
StreamChunk,
|
||||
@@ -10,8 +11,14 @@ import {
|
||||
parseSSELine,
|
||||
} from "./stream-utils";
|
||||
|
||||
function notifySubscribers(stream: ActiveStream, chunk: StreamChunk) {
|
||||
stream.chunks.push(chunk);
|
||||
function notifySubscribers(
|
||||
stream: ActiveStream,
|
||||
chunk: StreamChunk,
|
||||
skipStore = false,
|
||||
) {
|
||||
if (!skipStore) {
|
||||
stream.chunks.push(chunk);
|
||||
}
|
||||
for (const callback of stream.onChunkCallbacks) {
|
||||
try {
|
||||
callback(chunk);
|
||||
@@ -21,42 +28,133 @@ function notifySubscribers(stream: ActiveStream, chunk: StreamChunk) {
|
||||
}
|
||||
}
|
||||
|
||||
export async function executeStream(
|
||||
stream: ActiveStream,
|
||||
message: string,
|
||||
isUserMessage: boolean,
|
||||
context?: { url: string; content: string },
|
||||
retryCount: number = 0,
|
||||
/**
|
||||
* Options for stream execution.
|
||||
*/
|
||||
interface StreamExecutionOptions {
|
||||
/** The active stream state object */
|
||||
stream: ActiveStream;
|
||||
/** Execution mode: 'new' for new stream, 'reconnect' for task reconnection */
|
||||
mode: "new" | "reconnect";
|
||||
/** Message content (required for 'new' mode) */
|
||||
message?: string;
|
||||
/** Whether this is a user message (for 'new' mode) */
|
||||
isUserMessage?: boolean;
|
||||
/** Optional context for the message (for 'new' mode) */
|
||||
context?: { url: string; content: string };
|
||||
/** Task ID (required for 'reconnect' mode) */
|
||||
taskId?: string;
|
||||
/** Last message ID for replay (for 'reconnect' mode) */
|
||||
lastMessageId?: string;
|
||||
/** Current retry count (internal use) */
|
||||
retryCount?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Unified stream execution function that handles both new streams and task reconnection.
|
||||
*
|
||||
* For new streams:
|
||||
* - Posts a message to create a new chat stream
|
||||
* - Reads SSE chunks and notifies subscribers
|
||||
*
|
||||
* For reconnection:
|
||||
* - Connects to an existing task stream
|
||||
* - Replays messages from lastMessageId position
|
||||
* - Allows resumption of long-running operations
|
||||
*/
|
||||
async function executeStreamInternal(
|
||||
options: StreamExecutionOptions,
|
||||
): Promise<void> {
|
||||
const {
|
||||
stream,
|
||||
mode,
|
||||
message,
|
||||
isUserMessage,
|
||||
context,
|
||||
taskId,
|
||||
lastMessageId = INITIAL_MESSAGE_ID,
|
||||
retryCount = 0,
|
||||
} = options;
|
||||
|
||||
const { sessionId, abortController } = stream;
|
||||
const isReconnect = mode === "reconnect";
|
||||
const logPrefix = isReconnect ? "[SSE-RECONNECT]" : "[StreamExecutor]";
|
||||
|
||||
if (isReconnect) {
|
||||
console.info(`${logPrefix} executeStream starting:`, {
|
||||
taskId,
|
||||
lastMessageId,
|
||||
retryCount,
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
const url = `/api/chat/sessions/${sessionId}/stream`;
|
||||
const body = JSON.stringify({
|
||||
message,
|
||||
is_user_message: isUserMessage,
|
||||
context: context || null,
|
||||
});
|
||||
// Build URL and request options based on mode
|
||||
let url: string;
|
||||
let fetchOptions: RequestInit;
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Accept: "text/event-stream",
|
||||
},
|
||||
body,
|
||||
signal: abortController.signal,
|
||||
});
|
||||
if (isReconnect) {
|
||||
url = `/api/chat/tasks/${taskId}/stream?last_message_id=${encodeURIComponent(lastMessageId)}`;
|
||||
fetchOptions = {
|
||||
method: "GET",
|
||||
headers: {
|
||||
Accept: "text/event-stream",
|
||||
},
|
||||
signal: abortController.signal,
|
||||
};
|
||||
console.info(`${logPrefix} Fetching task stream:`, { url });
|
||||
} else {
|
||||
url = `/api/chat/sessions/${sessionId}/stream`;
|
||||
fetchOptions = {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Accept: "text/event-stream",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
message,
|
||||
is_user_message: isUserMessage,
|
||||
context: context || null,
|
||||
}),
|
||||
signal: abortController.signal,
|
||||
};
|
||||
}
|
||||
|
||||
const response = await fetch(url, fetchOptions);
|
||||
|
||||
if (isReconnect) {
|
||||
console.info(`${logPrefix} Task stream response:`, {
|
||||
status: response.status,
|
||||
ok: response.ok,
|
||||
});
|
||||
}
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(errorText || `HTTP ${response.status}`);
|
||||
if (isReconnect) {
|
||||
console.error(`${logPrefix} Task stream error response:`, {
|
||||
status: response.status,
|
||||
errorText,
|
||||
});
|
||||
}
|
||||
// For reconnect: don't retry on 404/403 (permanent errors)
|
||||
const isPermanentError =
|
||||
isReconnect && (response.status === 404 || response.status === 403);
|
||||
const error = new Error(errorText || `HTTP ${response.status}`);
|
||||
(error as Error & { status?: number }).status = response.status;
|
||||
(error as Error & { isPermanent?: boolean }).isPermanent =
|
||||
isPermanentError;
|
||||
throw error;
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
throw new Error("Response body is null");
|
||||
}
|
||||
|
||||
if (isReconnect) {
|
||||
console.info(`${logPrefix} Task stream connected, reading chunks...`);
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
@@ -65,6 +163,11 @@ export async function executeStream(
|
||||
const { done, value } = await reader.read();
|
||||
|
||||
if (done) {
|
||||
if (isReconnect) {
|
||||
console.info(
|
||||
`${logPrefix} Task stream reader done (connection closed)`,
|
||||
);
|
||||
}
|
||||
notifySubscribers(stream, { type: "stream_end" });
|
||||
stream.status = "completed";
|
||||
return;
|
||||
@@ -78,6 +181,9 @@ export async function executeStream(
|
||||
const data = parseSSELine(line);
|
||||
if (data !== null) {
|
||||
if (data === "[DONE]") {
|
||||
if (isReconnect) {
|
||||
console.info(`${logPrefix} Task stream received [DONE] signal`);
|
||||
}
|
||||
notifySubscribers(stream, { type: "stream_end" });
|
||||
stream.status = "completed";
|
||||
return;
|
||||
@@ -90,14 +196,30 @@ export async function executeStream(
|
||||
const chunk = normalizeStreamChunk(rawChunk);
|
||||
if (!chunk) continue;
|
||||
|
||||
// Log first few chunks for debugging (reconnect mode only)
|
||||
if (isReconnect && stream.chunks.length < 3) {
|
||||
console.info(`${logPrefix} Task stream chunk received:`, {
|
||||
type: chunk.type,
|
||||
chunkIndex: stream.chunks.length,
|
||||
});
|
||||
}
|
||||
|
||||
notifySubscribers(stream, chunk);
|
||||
|
||||
if (chunk.type === "stream_end") {
|
||||
if (isReconnect) {
|
||||
console.info(
|
||||
`${logPrefix} Task stream completed via stream_end chunk`,
|
||||
);
|
||||
}
|
||||
stream.status = "completed";
|
||||
return;
|
||||
}
|
||||
|
||||
if (chunk.type === "error") {
|
||||
if (isReconnect) {
|
||||
console.error(`${logPrefix} Task stream error chunk:`, chunk);
|
||||
}
|
||||
stream.status = "error";
|
||||
stream.error = new Error(
|
||||
chunk.message || chunk.content || "Stream error",
|
||||
@@ -105,7 +227,7 @@ export async function executeStream(
|
||||
return;
|
||||
}
|
||||
} catch (err) {
|
||||
console.warn("[StreamExecutor] Failed to parse SSE chunk:", err);
|
||||
console.warn(`${logPrefix} Failed to parse SSE chunk:`, err);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -117,18 +239,27 @@ export async function executeStream(
|
||||
return;
|
||||
}
|
||||
|
||||
if (retryCount < MAX_RETRIES) {
|
||||
// Check if this is a permanent error (404/403) that shouldn't be retried
|
||||
const isPermanentError =
|
||||
err instanceof Error &&
|
||||
(err as Error & { isPermanent?: boolean }).isPermanent;
|
||||
|
||||
if (!isPermanentError && retryCount < MAX_RETRIES) {
|
||||
const retryDelay = INITIAL_RETRY_DELAY * Math.pow(2, retryCount);
|
||||
console.log(
|
||||
`[StreamExecutor] Retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`,
|
||||
`${logPrefix} Retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`,
|
||||
);
|
||||
await new Promise((resolve) => setTimeout(resolve, retryDelay));
|
||||
return executeStream(
|
||||
stream,
|
||||
message,
|
||||
isUserMessage,
|
||||
context,
|
||||
retryCount + 1,
|
||||
return executeStreamInternal({
|
||||
...options,
|
||||
retryCount: retryCount + 1,
|
||||
});
|
||||
}
|
||||
|
||||
// Log permanent errors differently for debugging
|
||||
if (isPermanentError) {
|
||||
console.log(
|
||||
`${logPrefix} Stream failed permanently (task not found or access denied): ${(err as Error).message}`,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -140,3 +271,52 @@ export async function executeStream(
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a new chat stream.
|
||||
*
|
||||
* Posts a message to create a new stream and reads SSE responses.
|
||||
*/
|
||||
export async function executeStream(
|
||||
stream: ActiveStream,
|
||||
message: string,
|
||||
isUserMessage: boolean,
|
||||
context?: { url: string; content: string },
|
||||
retryCount: number = 0,
|
||||
): Promise<void> {
|
||||
return executeStreamInternal({
|
||||
stream,
|
||||
mode: "new",
|
||||
message,
|
||||
isUserMessage,
|
||||
context,
|
||||
retryCount,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Reconnect to an existing task stream.
|
||||
*
|
||||
* This is used when a client wants to resume receiving updates from a
|
||||
* long-running background task. Messages are replayed from the last_message_id
|
||||
* position, allowing clients to catch up on missed events.
|
||||
*
|
||||
* @param stream - The active stream state
|
||||
* @param taskId - The task ID to reconnect to
|
||||
* @param lastMessageId - The last message ID received (for replay)
|
||||
* @param retryCount - Current retry count
|
||||
*/
|
||||
export async function executeTaskReconnect(
|
||||
stream: ActiveStream,
|
||||
taskId: string,
|
||||
lastMessageId: string = INITIAL_MESSAGE_ID,
|
||||
retryCount: number = 0,
|
||||
): Promise<void> {
|
||||
return executeStreamInternal({
|
||||
stream,
|
||||
mode: "reconnect",
|
||||
taskId,
|
||||
lastMessageId,
|
||||
retryCount,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -28,7 +28,8 @@ export function normalizeStreamChunk(
|
||||
|
||||
switch (chunk.type) {
|
||||
case "text-delta":
|
||||
return { type: "text_chunk", content: chunk.delta };
|
||||
// Backend sends "content", Vercel AI SDK sends "delta"
|
||||
return { type: "text_chunk", content: chunk.delta || chunk.content };
|
||||
case "text-end":
|
||||
return { type: "text_ended" };
|
||||
case "tool-input-available":
|
||||
@@ -63,6 +64,10 @@ export function normalizeStreamChunk(
|
||||
case "finish":
|
||||
return { type: "stream_end" };
|
||||
case "start":
|
||||
// Start event with optional taskId for reconnection
|
||||
return chunk.taskId
|
||||
? { type: "stream_start", taskId: chunk.taskId }
|
||||
: null;
|
||||
case "text-start":
|
||||
return null;
|
||||
case "tool-input-start":
|
||||
|
||||
Reference in New Issue
Block a user