Compare commits

...

10 Commits

Author SHA1 Message Date
Swifty
d1da7fe5da remove call for onbaording step 2026-01-30 12:15:35 +01:00
Swifty
11e27cfdcf Merge branch 'dev' into swiftyos/sse-long-running-tasks 2026-01-30 12:01:45 +01:00
Swifty
0be5fedc86 updating sse reconection logic be 2026-01-30 11:58:42 +01:00
Swifty
f2e81648b5 updating SSE reconnection logic 2026-01-30 11:58:25 +01:00
Swifty
bb608ea60d pr comments 2026-01-29 22:29:17 +01:00
Swifty
46af3b94f2 Merge branch 'swiftyos/sse-long-running-tasks' of github.com:Significant-Gravitas/AutoGPT into swiftyos/sse-long-running-tasks 2026-01-29 18:03:01 +01:00
Swifty
083cceca0f fixing edge cases 2026-01-29 18:02:21 +01:00
Swifty
06758adefd Merge branch 'dev' into swiftyos/sse-long-running-tasks 2026-01-29 13:33:32 +01:00
Swifty
c01c29a059 fmt issues 2026-01-29 13:28:01 +01:00
Swifty
d738059da8 added long running task support 2026-01-29 10:24:14 +01:00
25 changed files with 2538 additions and 181 deletions

View File

@@ -0,0 +1,325 @@
"""RabbitMQ consumer for operation completion messages.
This module provides a consumer that listens for completion notifications
from external services (like Agent Generator) and triggers the appropriate
stream registry and chat service updates.
"""
import asyncio
import logging
import orjson
from pydantic import BaseModel
from backend.data.rabbitmq import (
AsyncRabbitMQ,
Exchange,
ExchangeType,
Queue,
RabbitMQConfig,
)
from . import service as chat_service
from . import stream_registry
from .response_model import StreamError, StreamFinish, StreamToolOutputAvailable
from .tools.models import ErrorResponse
logger = logging.getLogger(__name__)
# Queue and exchange configuration
OPERATION_COMPLETE_EXCHANGE = Exchange(
name="chat_operations",
type=ExchangeType.DIRECT,
durable=True,
)
OPERATION_COMPLETE_QUEUE = Queue(
name="chat_operation_complete",
durable=True,
exchange=OPERATION_COMPLETE_EXCHANGE,
routing_key="operation.complete",
)
RABBITMQ_CONFIG = RabbitMQConfig(
exchanges=[OPERATION_COMPLETE_EXCHANGE],
queues=[OPERATION_COMPLETE_QUEUE],
)
class OperationCompleteMessage(BaseModel):
"""Message format for operation completion notifications."""
operation_id: str
task_id: str
success: bool
result: dict | str | None = None
error: str | None = None
class ChatCompletionConsumer:
"""Consumer for chat operation completion messages from RabbitMQ."""
def __init__(self):
self._rabbitmq: AsyncRabbitMQ | None = None
self._consumer_task: asyncio.Task | None = None
self._running = False
async def start(self) -> None:
"""Start the completion consumer."""
if self._running:
logger.warning("Completion consumer already running")
return
self._rabbitmq = AsyncRabbitMQ(RABBITMQ_CONFIG)
await self._rabbitmq.connect()
self._running = True
self._consumer_task = asyncio.create_task(self._consume_messages())
logger.info("Chat completion consumer started")
async def stop(self) -> None:
"""Stop the completion consumer."""
self._running = False
if self._consumer_task:
self._consumer_task.cancel()
try:
await self._consumer_task
except asyncio.CancelledError:
pass
self._consumer_task = None
if self._rabbitmq:
await self._rabbitmq.disconnect()
self._rabbitmq = None
logger.info("Chat completion consumer stopped")
async def _consume_messages(self) -> None:
"""Main message consumption loop with retry logic."""
max_retries = 10
retry_delay = 5 # seconds
retry_count = 0
while self._running and retry_count < max_retries:
if not self._rabbitmq:
logger.error("RabbitMQ not initialized")
return
try:
channel = await self._rabbitmq.get_channel()
queue = await channel.get_queue(OPERATION_COMPLETE_QUEUE.name)
# Reset retry count on successful connection
retry_count = 0
async with queue.iterator() as queue_iter:
async for message in queue_iter:
if not self._running:
return
try:
async with message.process():
await self._handle_message(message.body)
except Exception as e:
logger.error(
f"Error processing completion message: {e}",
exc_info=True,
)
# Message will be requeued due to exception
except asyncio.CancelledError:
logger.info("Consumer cancelled")
return
except Exception as e:
retry_count += 1
logger.error(
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
exc_info=True,
)
if self._running and retry_count < max_retries:
await asyncio.sleep(retry_delay)
else:
logger.error("Max retries reached, stopping consumer")
return
async def _handle_message(self, body: bytes) -> None:
"""Handle a single completion message."""
try:
data = orjson.loads(body)
message = OperationCompleteMessage(**data)
except Exception as e:
logger.error(f"Failed to parse completion message: {e}")
return
logger.info(
f"Received completion for operation {message.operation_id} "
f"(task_id={message.task_id}, success={message.success})"
)
# Find task in registry
task = await stream_registry.find_task_by_operation_id(message.operation_id)
if task is None:
# Try to look up by task_id directly
task = await stream_registry.get_task(message.task_id)
if task is None:
logger.warning(
f"Task not found for operation {message.operation_id} "
f"(task_id={message.task_id})"
)
return
if message.success:
await self._handle_success(task, message)
else:
await self._handle_failure(task, message)
async def _handle_success(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle successful operation completion."""
# Publish result to stream registry
result_output = message.result if message.result else {"status": "completed"}
await stream_registry.publish_chunk(
task.task_id,
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
output=(
result_output
if isinstance(result_output, str)
else orjson.dumps(result_output).decode("utf-8")
),
success=True,
),
)
# Update pending operation in database
result_str = (
message.result
if isinstance(message.result, str)
else (
orjson.dumps(message.result).decode("utf-8")
if message.result
else '{"status": "completed"}'
)
)
await chat_service._update_pending_operation(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
result=result_str,
)
# Generate LLM continuation with streaming
await chat_service._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
# Mark task as completed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="completed")
await chat_service._mark_operation_completed(task.tool_call_id)
logger.info(
f"Successfully processed completion for task {task.task_id} "
f"(operation {message.operation_id})"
)
async def _handle_failure(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle failed operation completion."""
error_msg = message.error or "Operation failed"
# Publish error to stream registry followed by finish event
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText=error_msg),
)
await stream_registry.publish_chunk(task.task_id, StreamFinish())
# Update pending operation with error
error_response = ErrorResponse(
message=error_msg,
error=message.error,
)
await chat_service._update_pending_operation(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
result=error_response.model_dump_json(),
)
# Mark task as failed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="failed")
await chat_service._mark_operation_completed(task.tool_call_id)
logger.info(
f"Processed failure for task {task.task_id} "
f"(operation {message.operation_id}): {error_msg}"
)
# Module-level consumer instance
_consumer: ChatCompletionConsumer | None = None
async def start_completion_consumer() -> None:
"""Start the global completion consumer."""
global _consumer
if _consumer is None:
_consumer = ChatCompletionConsumer()
await _consumer.start()
async def stop_completion_consumer() -> None:
"""Stop the global completion consumer."""
global _consumer
if _consumer:
await _consumer.stop()
_consumer = None
async def publish_operation_complete(
operation_id: str,
task_id: str,
success: bool,
result: dict | str | None = None,
error: str | None = None,
) -> None:
"""Publish an operation completion message.
This is a helper function for testing or for services that want to
publish completion messages directly.
Args:
operation_id: The operation ID that completed.
task_id: The task ID associated with the operation.
success: Whether the operation succeeded.
result: The result data (for success).
error: The error message (for failure).
"""
message = OperationCompleteMessage(
operation_id=operation_id,
task_id=task_id,
success=success,
result=result,
error=error,
)
rabbitmq = AsyncRabbitMQ(RABBITMQ_CONFIG)
try:
await rabbitmq.connect()
await rabbitmq.publish_message(
routing_key="operation.complete",
message=message.model_dump_json(),
exchange=OPERATION_COMPLETE_EXCHANGE,
)
logger.info(f"Published completion for operation {operation_id}")
finally:
await rabbitmq.disconnect()

View File

@@ -44,6 +44,20 @@ class ChatConfig(BaseSettings):
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
)
# Stream registry configuration for SSE reconnection
stream_ttl: int = Field(
default=3600,
description="TTL in seconds for stream data in Redis (1 hour)",
)
stream_max_length: int = Field(
default=1000,
description="Maximum number of messages to store per stream",
)
internal_api_key: str | None = Field(
default=None,
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
)
# Langfuse Prompt Management Configuration
# Note: Langfuse credentials are in Settings().secrets (settings.py)
langfuse_prompt_name: str = Field(
@@ -82,6 +96,14 @@ class ChatConfig(BaseSettings):
v = "https://openrouter.ai/api/v1"
return v
@field_validator("internal_api_key", mode="before")
@classmethod
def get_internal_api_key(cls, v):
"""Get internal API key from environment if not provided."""
if v is None:
v = os.getenv("CHAT_INTERNAL_API_KEY")
return v
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",

View File

@@ -4,16 +4,19 @@ import logging
from collections.abc import AsyncGenerator
from typing import Annotated
import orjson
from autogpt_libs import auth
from fastapi import APIRouter, Depends, Query, Security
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from backend.util.exceptions import NotFoundError
from . import service as chat_service
from . import stream_registry
from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .response_model import StreamFinish, StreamHeartbeat
config = ChatConfig()
@@ -81,6 +84,14 @@ class ListSessionsResponse(BaseModel):
total: int
class OperationCompleteRequest(BaseModel):
"""Request model for external completion webhook."""
success: bool
result: dict | str | None = None
error: str | None = None
# ========== Routes ==========
@@ -366,6 +377,267 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Task Streaming (SSE Reconnection) ==========
@router.get(
"/tasks/{task_id}/stream",
)
async def stream_task(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
last_message_id: str = Query(
default="0-0",
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
),
):
"""
Reconnect to a long-running task's SSE stream.
When a long-running operation (like agent generation) starts, the client
receives a task_id. If the connection drops, the client can reconnect
using this endpoint to resume receiving updates.
Args:
task_id: The task ID from the operation_started response.
user_id: Authenticated user ID for ownership validation.
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
Returns:
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
Raises:
NotFoundError: If task_id is not found or user doesn't have access.
"""
# Get subscriber queue from stream registry
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id=last_message_id,
)
if subscriber_queue is None:
raise NotFoundError(f"Task {task_id} not found or access denied.")
async def event_generator() -> AsyncGenerator[str, None]:
import asyncio
chunk_count = 0
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try:
while True:
try:
# Wait for next chunk with timeout for heartbeats
chunk = await asyncio.wait_for(
subscriber_queue.get(), timeout=heartbeat_interval
)
chunk_count += 1
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
logger.info(
f"Task stream completed for task {task_id}, "
f"chunk_count={chunk_count}"
)
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except Exception as e:
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
finally:
# Unsubscribe when client disconnects or stream ends
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
},
)
@router.get(
"/tasks/{task_id}",
)
async def get_task_status(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
) -> dict:
"""
Get the status of a long-running task.
Args:
task_id: The task ID to check.
user_id: Authenticated user ID for ownership validation.
Returns:
dict: Task status including task_id, status, tool_name, and operation_id.
Raises:
NotFoundError: If task_id is not found or user doesn't have access.
"""
task = await stream_registry.get_task(task_id)
if task is None:
raise NotFoundError(f"Task {task_id} not found.")
# Validate ownership
if user_id and task.user_id and task.user_id != user_id:
raise NotFoundError(f"Task {task_id} not found.")
return {
"task_id": task.task_id,
"session_id": task.session_id,
"status": task.status,
"tool_name": task.tool_name,
"operation_id": task.operation_id,
"created_at": task.created_at.isoformat(),
}
# ========== External Completion Webhook ==========
@router.post(
"/operations/{operation_id}/complete",
status_code=200,
)
async def complete_operation(
operation_id: str,
request: OperationCompleteRequest,
x_api_key: str | None = Header(default=None),
) -> dict:
"""
External completion webhook for long-running operations.
Called by Agent Generator (or other services) when an operation completes.
This triggers the stream registry to publish completion and continue LLM generation.
Args:
operation_id: The operation ID to complete.
request: Completion payload with success status and result/error.
x_api_key: Internal API key for authentication.
Returns:
dict: Status of the completion.
Raises:
HTTPException: If API key is invalid or operation not found.
"""
# Validate internal API key - reject if not configured or invalid
if not config.internal_api_key:
logger.error(
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
)
raise HTTPException(
status_code=503,
detail="Webhook not available: internal API key not configured",
)
if x_api_key != config.internal_api_key:
raise HTTPException(status_code=401, detail="Invalid API key")
# Find task by operation_id
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None:
raise HTTPException(
status_code=404,
detail=f"Operation {operation_id} not found",
)
logger.info(
f"Received completion webhook for operation {operation_id} "
f"(task_id={task.task_id}, success={request.success})"
)
if request.success:
# Publish result to stream registry
from .response_model import StreamToolOutputAvailable
result_output = request.result if request.result else {"status": "completed"}
await stream_registry.publish_chunk(
task.task_id,
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
output=(
result_output
if isinstance(result_output, str)
else orjson.dumps(result_output).decode("utf-8")
),
success=True,
),
)
# Update pending operation in database
from . import service as svc
result_str = (
request.result
if isinstance(request.result, str)
else (
orjson.dumps(request.result).decode("utf-8")
if request.result
else '{"status": "completed"}'
)
)
await svc._update_pending_operation(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
result=result_str,
)
# Generate LLM continuation with streaming
await svc._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
# Mark task as completed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="completed")
await svc._mark_operation_completed(task.tool_call_id)
else:
# Publish error to stream registry
from .response_model import StreamError
error_msg = request.error or "Operation failed"
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText=error_msg),
)
# Send finish event to end the stream
await stream_registry.publish_chunk(task.task_id, StreamFinish())
# Update pending operation with error
from . import service as svc
from .tools.models import ErrorResponse
error_response = ErrorResponse(
message=error_msg,
error=request.error,
)
await svc._update_pending_operation(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
result=error_response.model_dump_json(),
)
# Mark task as failed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="failed")
await svc._mark_operation_completed(task.tool_call_id)
return {"status": "ok", "task_id": task.task_id}
# ========== Health Check ==========

View File

@@ -26,6 +26,7 @@ from backend.util.exceptions import NotFoundError
from backend.util.settings import Settings
from . import db as chat_db
from . import stream_registry
from .config import ChatConfig
from .model import (
ChatMessage,
@@ -1610,8 +1611,9 @@ async def _yield_tool_call(
)
return
# Generate operation ID
# Generate operation ID and task ID
operation_id = str(uuid_module.uuid4())
task_id = str(uuid_module.uuid4())
# Build a user-friendly message based on tool and arguments
if tool_name == "create_agent":
@@ -1654,6 +1656,16 @@ async def _yield_tool_call(
# Wrap session save and task creation in try-except to release lock on failure
try:
# Create task in stream registry for SSE reconnection support
await stream_registry.create_task(
task_id=task_id,
session_id=session.session_id,
user_id=session.user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
operation_id=operation_id,
)
# Save assistant message with tool_call FIRST (required by LLM)
assistant_message = ChatMessage(
role="assistant",
@@ -1675,23 +1687,27 @@ async def _yield_tool_call(
session.messages.append(pending_message)
await upsert_chat_session(session)
logger.info(
f"Saved pending operation {operation_id} for tool {tool_name} "
f"in session {session.session_id}"
f"Saved pending operation {operation_id} (task_id={task_id}) "
f"for tool {tool_name} in session {session.session_id}"
)
# Store task reference in module-level set to prevent GC before completion
task = asyncio.create_task(
_execute_long_running_tool(
bg_task = asyncio.create_task(
_execute_long_running_tool_with_streaming(
tool_name=tool_name,
parameters=arguments,
tool_call_id=tool_call_id,
operation_id=operation_id,
task_id=task_id,
session_id=session.session_id,
user_id=session.user_id,
)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
_background_tasks.add(bg_task)
bg_task.add_done_callback(_background_tasks.discard)
# Associate the asyncio task with the stream registry task
await stream_registry.set_task_asyncio_task(task_id, bg_task)
except Exception as e:
# Roll back appended messages to prevent data corruption on subsequent saves
if (
@@ -1709,6 +1725,11 @@ async def _yield_tool_call(
# Release the Redis lock since the background task won't be spawned
await _mark_operation_completed(tool_call_id)
# Mark stream registry task as failed if it was created
try:
await stream_registry.mark_task_completed(task_id, status="failed")
except Exception:
pass
logger.error(
f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
)
@@ -1722,6 +1743,7 @@ async def _yield_tool_call(
message=started_msg,
operation_id=operation_id,
tool_name=tool_name,
task_id=task_id, # Include task_id for SSE reconnection
).model_dump_json(),
success=True,
)
@@ -1791,6 +1813,9 @@ async def _execute_long_running_tool(
This function runs independently of the SSE connection, so the operation
survives if the user closes their browser tab.
NOTE: This is the legacy function without stream registry support.
Use _execute_long_running_tool_with_streaming for new implementations.
"""
try:
# Load fresh session (not stale reference)
@@ -1838,6 +1863,128 @@ async def _execute_long_running_tool(
await _mark_operation_completed(tool_call_id)
async def _execute_long_running_tool_with_streaming(
tool_name: str,
parameters: dict[str, Any],
tool_call_id: str,
operation_id: str,
task_id: str,
session_id: str,
user_id: str | None,
) -> None:
"""Execute a long-running tool with stream registry support for SSE reconnection.
This function runs independently of the SSE connection, publishes progress
to the stream registry, and survives if the user closes their browser tab.
Clients can reconnect via GET /chat/tasks/{task_id}/stream to resume streaming.
If the external service returns a 202 Accepted (async), this function exits
early and lets the RabbitMQ completion consumer handle the rest.
"""
# Track whether we delegated to async processing - if so, the RabbitMQ
# completion consumer will handle cleanup, not us
delegated_to_async = False
try:
# Load fresh session (not stale reference)
session = await get_chat_session(session_id, user_id)
if not session:
logger.error(f"Session {session_id} not found for background tool")
await stream_registry.mark_task_completed(task_id, status="failed")
return
# Pass operation_id and task_id to the tool for async processing
enriched_parameters = {
**parameters,
"_operation_id": operation_id,
"_task_id": task_id,
}
# Execute the actual tool
result = await execute_tool(
tool_name=tool_name,
parameters=enriched_parameters,
tool_call_id=tool_call_id,
user_id=user_id,
session=session,
)
# Check if the tool result indicates async processing
# (e.g., Agent Generator returned 202 Accepted)
try:
result_data = orjson.loads(result.output) if result.output else {}
if result_data.get("status") == "accepted":
logger.info(
f"Tool {tool_name} delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id}). "
f"RabbitMQ completion consumer will handle the rest."
)
# Don't publish result, don't continue with LLM, and don't cleanup
# The RabbitMQ consumer will handle everything when the external
# service completes and publishes to the queue
delegated_to_async = True
return
except (orjson.JSONDecodeError, TypeError):
pass # Not JSON or not async - continue normally
# Publish tool result to stream registry
await stream_registry.publish_chunk(task_id, result)
# Update the pending message with result
result_str = (
result.output
if isinstance(result.output, str)
else orjson.dumps(result.output).decode("utf-8")
)
await _update_pending_operation(
session_id=session_id,
tool_call_id=tool_call_id,
result=result_str,
)
logger.info(
f"Background tool {tool_name} completed for session {session_id} "
f"(task_id={task_id})"
)
# Generate LLM continuation and stream chunks to registry
await _generate_llm_continuation_with_streaming(
session_id=session_id,
user_id=user_id,
task_id=task_id,
)
# Mark task as completed in stream registry
await stream_registry.mark_task_completed(task_id, status="completed")
except Exception as e:
logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True)
error_response = ErrorResponse(
message=f"Tool {tool_name} failed: {str(e)}",
)
# Publish error to stream registry followed by finish event
await stream_registry.publish_chunk(
task_id,
StreamError(errorText=str(e)),
)
await stream_registry.publish_chunk(task_id, StreamFinish())
await _update_pending_operation(
session_id=session_id,
tool_call_id=tool_call_id,
result=error_response.model_dump_json(),
)
# Mark task as failed in stream registry
await stream_registry.mark_task_completed(task_id, status="failed")
finally:
# Only cleanup if we didn't delegate to async processing
# For async path, the RabbitMQ completion consumer handles cleanup
if not delegated_to_async:
await _mark_operation_completed(tool_call_id)
async def _update_pending_operation(
session_id: str,
tool_call_id: str,
@@ -1964,3 +2111,128 @@ async def _generate_llm_continuation(
except Exception as e:
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
async def _generate_llm_continuation_with_streaming(
session_id: str,
user_id: str | None,
task_id: str,
) -> None:
"""Generate an LLM response with streaming to the stream registry.
This is called by background tasks to continue the conversation
after a tool result is saved. Chunks are published to the stream registry
so reconnecting clients can receive them.
"""
import uuid as uuid_module
try:
# Load fresh session from DB (bypass cache to get the updated tool result)
await invalidate_session_cache(session_id)
session = await get_chat_session(session_id, user_id)
if not session:
logger.error(f"Session {session_id} not found for LLM continuation")
return
# Build system prompt
system_prompt, _ = await _build_system_prompt(user_id)
# Build messages in OpenAI format
messages = session.to_openai_messages()
if system_prompt:
from openai.types.chat import ChatCompletionSystemMessageParam
system_message = ChatCompletionSystemMessageParam(
role="system",
content=system_prompt,
)
messages = [system_message] + messages
# Build extra_body for tracing
extra_body: dict[str, Any] = {
"posthogProperties": {
"environment": settings.config.app_env.value,
},
}
if user_id:
extra_body["user"] = user_id[:128]
extra_body["posthogDistinctId"] = user_id
if session_id:
extra_body["session_id"] = session_id[:128]
# Make streaming LLM call (no tools - just text response)
from typing import cast
from openai.types.chat import ChatCompletionMessageParam
# Generate unique IDs for AI SDK protocol
message_id = str(uuid_module.uuid4())
text_block_id = str(uuid_module.uuid4())
# Publish start event
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
# Stream the response
stream = await client.chat.completions.create(
model=config.model,
messages=cast(list[ChatCompletionMessageParam], messages),
extra_body=extra_body,
stream=True,
)
assistant_content = ""
async for chunk in stream:
if chunk.choices and chunk.choices[0].delta.content:
delta = chunk.choices[0].delta.content
assistant_content += delta
# Publish delta to stream registry
await stream_registry.publish_chunk(
task_id,
StreamTextDelta(id=text_block_id, delta=delta),
)
# Publish end events
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
if assistant_content:
# Reload session from DB to avoid race condition with user messages
fresh_session = await get_chat_session(session_id, user_id)
if not fresh_session:
logger.error(
f"Session {session_id} disappeared during LLM continuation"
)
return
# Save assistant message to database
assistant_message = ChatMessage(
role="assistant",
content=assistant_content,
)
fresh_session.messages.append(assistant_message)
# Save to database (not cache) to persist the response
await upsert_chat_session(fresh_session)
# Invalidate cache so next poll/refresh gets fresh data
await invalidate_session_cache(session_id)
logger.info(
f"Generated streaming LLM continuation for session {session_id} "
f"(task_id={task_id}), response length: {len(assistant_content)}"
)
else:
logger.warning(
f"Streaming LLM continuation returned empty response for {session_id}"
)
except Exception as e:
logger.error(
f"Failed to generate streaming LLM continuation: {e}", exc_info=True
)
# Publish error to stream registry followed by finish event
await stream_registry.publish_chunk(
task_id,
StreamError(errorText=f"Failed to generate response: {e}"),
)
await stream_registry.publish_chunk(task_id, StreamFinish())

View File

@@ -0,0 +1,648 @@
"""Stream registry for managing reconnectable SSE streams.
This module provides a registry for tracking active streaming tasks and their
messages. It supports:
- Creating tasks with unique IDs for long-running operations
- Publishing stream messages to both Redis Streams and in-memory queues
- Subscribing to tasks with replay of missed messages
- Looking up tasks by operation_id for webhook callbacks
- Cross-pod real-time delivery via Redis pub/sub
"""
import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal
import orjson
from backend.data.redis_client import get_redis_async
from .config import ChatConfig
from .response_model import StreamBaseResponse, StreamFinish
logger = logging.getLogger(__name__)
config = ChatConfig()
# Track active pub/sub listeners for cross-pod delivery
_pubsub_listeners: dict[str, asyncio.Task] = {}
@dataclass
class ActiveTask:
"""Represents an active streaming task."""
task_id: str
session_id: str
user_id: str | None
tool_call_id: str
tool_name: str
operation_id: str
status: Literal["running", "completed", "failed"] = "running"
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
queue: asyncio.Queue[StreamBaseResponse] = field(default_factory=asyncio.Queue)
asyncio_task: asyncio.Task | None = None
# Lock for atomic status checks and subscriber management
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
# Set of subscriber queues for fan-out
subscribers: set[asyncio.Queue[StreamBaseResponse]] = field(default_factory=set)
# Module-level registry for active tasks
_active_tasks: dict[str, ActiveTask] = {}
# Redis key patterns
TASK_META_PREFIX = "chat:task:meta:" # Hash for task metadata
TASK_STREAM_PREFIX = "chat:stream:" # Redis Stream for messages
TASK_OP_PREFIX = "chat:task:op:" # Operation ID -> task_id mapping
TASK_PUBSUB_PREFIX = "chat:task:pubsub:" # Pub/sub channel for cross-pod delivery
def _get_task_meta_key(task_id: str) -> str:
"""Get Redis key for task metadata."""
return f"{TASK_META_PREFIX}{task_id}"
def _get_task_stream_key(task_id: str) -> str:
"""Get Redis key for task message stream."""
return f"{TASK_STREAM_PREFIX}{task_id}"
def _get_operation_mapping_key(operation_id: str) -> str:
"""Get Redis key for operation_id to task_id mapping."""
return f"{TASK_OP_PREFIX}{operation_id}"
def _get_task_pubsub_channel(task_id: str) -> str:
"""Get Redis pub/sub channel for task cross-pod delivery."""
return f"{TASK_PUBSUB_PREFIX}{task_id}"
async def create_task(
task_id: str,
session_id: str,
user_id: str | None,
tool_call_id: str,
tool_name: str,
operation_id: str,
) -> ActiveTask:
"""Create a new streaming task in memory and Redis.
Args:
task_id: Unique identifier for the task
session_id: Chat session ID
user_id: User ID (may be None for anonymous)
tool_call_id: Tool call ID from the LLM
tool_name: Name of the tool being executed
operation_id: Operation ID for webhook callbacks
Returns:
The created ActiveTask instance
"""
task = ActiveTask(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
operation_id=operation_id,
)
# Store in memory registry
_active_tasks[task_id] = task
# Store metadata in Redis for durability
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
op_key = _get_operation_mapping_key(operation_id)
await redis.hset( # type: ignore[misc]
meta_key,
mapping={
"task_id": task_id,
"session_id": session_id,
"user_id": user_id or "",
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"operation_id": operation_id,
"status": task.status,
"created_at": task.created_at.isoformat(),
},
)
await redis.expire(meta_key, config.stream_ttl)
# Create operation_id -> task_id mapping for webhook lookups
await redis.set(op_key, task_id, ex=config.stream_ttl)
logger.info(
f"Created streaming task {task_id} for operation {operation_id} "
f"in session {session_id}"
)
return task
async def publish_chunk(
task_id: str,
chunk: StreamBaseResponse,
) -> str:
"""Publish a chunk to the task's stream.
Delivers to in-memory subscribers first (for real-time), then persists to
Redis Stream (for replay). This order ensures live subscribers get messages
even if Redis temporarily fails.
Args:
task_id: Task ID to publish to
chunk: The stream response chunk to publish
Returns:
The Redis Stream message ID (format: "timestamp-sequence"), or "0-0" if
Redis persistence failed
"""
# Deliver to in-memory subscribers FIRST for real-time updates
task = _active_tasks.get(task_id)
if task:
async with task.lock:
for subscriber_queue in task.subscribers:
try:
subscriber_queue.put_nowait(chunk)
except asyncio.QueueFull:
logger.warning(
f"Subscriber queue full for task {task_id}, dropping chunk"
)
# Then persist to Redis Stream for replay (with error handling)
message_id = "0-0"
chunk_json = chunk.model_dump_json()
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
# Add to Redis Stream with auto-generated ID
# The ID format is "timestamp-sequence" which gives us ordering
raw_id = await redis.xadd(
stream_key,
{"data": chunk_json},
maxlen=config.stream_max_length,
)
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
# Publish to pub/sub for cross-pod real-time delivery
pubsub_channel = _get_task_pubsub_channel(task_id)
await redis.publish(pubsub_channel, chunk_json)
logger.debug(f"Published chunk to task {task_id}, message_id={message_id}")
except Exception as e:
logger.error(
f"Failed to persist chunk to Redis for task {task_id}: {e}",
exc_info=True,
)
return message_id
async def subscribe_to_task(
task_id: str,
user_id: str | None,
last_message_id: str = "0-0",
) -> asyncio.Queue[StreamBaseResponse] | None:
"""Subscribe to a task's stream with replay of missed messages.
Args:
task_id: Task ID to subscribe to
user_id: User ID for ownership validation
last_message_id: Last Redis Stream message ID received ("0-0" for full replay)
Returns:
An asyncio Queue that will receive stream chunks, or None if task not found
or user doesn't have access
"""
# Check in-memory first
task = _active_tasks.get(task_id)
if task:
# Validate ownership
if user_id and task.user_id and task.user_id != user_id:
logger.warning(
f"User {user_id} attempted to subscribe to task {task_id} "
f"owned by {task.user_id}"
)
return None
# Create a new queue for this subscriber
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
# Replay from Redis Stream
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
# Track the last message ID we've seen for gap detection
replay_last_id = last_message_id
# Read all messages from stream starting after last_message_id
# xread returns messages with ID > last_message_id
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
if messages:
# messages format: [[stream_name, [(id, {data: json}), ...]]]
for _stream_name, stream_messages in messages:
for msg_id, msg_data in stream_messages:
# Track the last message ID we've processed
replay_last_id = (
msg_id if isinstance(msg_id, str) else msg_id.decode()
)
if b"data" in msg_data:
try:
chunk_data = orjson.loads(msg_data[b"data"])
# Reconstruct the appropriate response type
chunk = _reconstruct_chunk(chunk_data)
if chunk:
await subscriber_queue.put(chunk)
except Exception as e:
logger.warning(f"Failed to replay message: {e}")
# Atomically check status and register subscriber under lock
# This prevents race condition where task completes between check and subscribe
should_start_pubsub = False
async with task.lock:
if task.status == "running":
# Register this subscriber for live updates
task.subscribers.add(subscriber_queue)
# Start pub/sub listener if this is the first subscriber
should_start_pubsub = len(task.subscribers) == 1
logger.debug(
f"Registered subscriber for task {task_id}, "
f"total subscribers: {len(task.subscribers)}"
)
else:
# Task is done, add finish marker
await subscriber_queue.put(StreamFinish())
# After registering, do a second read to catch any messages published
# between the first read and registration (closes the race window)
if task.status == "running":
gap_messages = await redis.xread(
{stream_key: replay_last_id}, block=0, count=1000
)
if gap_messages:
for _stream_name, stream_messages in gap_messages:
for _msg_id, msg_data in stream_messages:
if b"data" in msg_data:
try:
chunk_data = orjson.loads(msg_data[b"data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
await subscriber_queue.put(chunk)
except Exception as e:
logger.warning(f"Failed to replay gap message: {e}")
# Start pub/sub listener outside the lock to avoid deadlocks
if should_start_pubsub:
await start_pubsub_listener(task_id)
return subscriber_queue
# Try to load from Redis if not in memory
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
logger.warning(f"Task {task_id} not found in memory or Redis")
return None
# Validate ownership
task_user_id = meta.get(b"user_id", b"").decode() or None
if user_id and task_user_id and task_user_id != user_id:
logger.warning(
f"User {user_id} attempted to subscribe to task {task_id} "
f"owned by {task_user_id}"
)
return None
# Replay from Redis Stream only (task is not in memory, so it's completed/crashed)
subscriber_queue = asyncio.Queue()
stream_key = _get_task_stream_key(task_id)
# Read all messages starting after last_message_id
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
if messages:
for _stream_name, stream_messages in messages:
for _msg_id, msg_data in stream_messages:
if b"data" in msg_data:
try:
chunk_data = orjson.loads(msg_data[b"data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
await subscriber_queue.put(chunk)
except Exception as e:
logger.warning(f"Failed to replay message: {e}")
# Add finish marker since task is not active
await subscriber_queue.put(StreamFinish())
return subscriber_queue
async def mark_task_completed(
task_id: str,
status: Literal["completed", "failed"] = "completed",
) -> None:
"""Mark a task as completed and publish final event.
Args:
task_id: Task ID to mark as completed
status: Final status ("completed" or "failed")
"""
task = _active_tasks.get(task_id)
if task:
# Acquire lock to prevent new subscribers during completion
async with task.lock:
task.status = status
# Send finish event directly to all current subscribers
finish_event = StreamFinish()
for subscriber_queue in task.subscribers:
try:
subscriber_queue.put_nowait(finish_event)
except asyncio.QueueFull:
logger.warning(
f"Subscriber queue full for task {task_id} during completion"
)
# Clear subscribers since task is done
task.subscribers.clear()
# Stop pub/sub listener since task is done
await stop_pubsub_listener(task_id)
# Also publish to Redis Stream for replay (and pub/sub for cross-pod)
await publish_chunk(task_id, StreamFinish())
# Remove from active tasks after a short delay to allow subscribers to finish
async def _cleanup():
await asyncio.sleep(5)
_active_tasks.pop(task_id, None)
logger.info(f"Cleaned up task {task_id} from memory")
asyncio.create_task(_cleanup())
# Update Redis metadata
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
await redis.hset(meta_key, "status", status) # type: ignore[misc]
logger.info(f"Marked task {task_id} as {status}")
async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
"""Find a task by its operation ID.
Used by webhook callbacks to locate the task to update.
Args:
operation_id: Operation ID to search for
Returns:
ActiveTask if found, None otherwise
"""
# Check in-memory first
for task in _active_tasks.values():
if task.operation_id == operation_id:
return task
# Try Redis lookup
redis = await get_redis_async()
op_key = _get_operation_mapping_key(operation_id)
task_id = await redis.get(op_key)
if task_id:
task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
# Check if task is in memory
if task_id_str in _active_tasks:
return _active_tasks[task_id_str]
# Load metadata from Redis
meta_key = _get_task_meta_key(task_id_str)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if meta:
# Reconstruct task object (not fully active, but has metadata)
return ActiveTask(
task_id=meta.get(b"task_id", b"").decode(),
session_id=meta.get(b"session_id", b"").decode(),
user_id=meta.get(b"user_id", b"").decode() or None,
tool_call_id=meta.get(b"tool_call_id", b"").decode(),
tool_name=meta.get(b"tool_name", b"").decode(),
operation_id=operation_id,
status=meta.get(b"status", b"running").decode(), # type: ignore
)
return None
async def get_task(task_id: str) -> ActiveTask | None:
"""Get a task by its ID.
Args:
task_id: Task ID to look up
Returns:
ActiveTask if found, None otherwise
"""
# Check in-memory first
if task_id in _active_tasks:
return _active_tasks[task_id]
# Try Redis lookup
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if meta:
return ActiveTask(
task_id=meta.get(b"task_id", b"").decode(),
session_id=meta.get(b"session_id", b"").decode(),
user_id=meta.get(b"user_id", b"").decode() or None,
tool_call_id=meta.get(b"tool_call_id", b"").decode(),
tool_name=meta.get(b"tool_name", b"").decode(),
operation_id=meta.get(b"operation_id", b"").decode(),
status=meta.get(b"status", b"running").decode(), # type: ignore[arg-type]
)
return None
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
"""Reconstruct a StreamBaseResponse from JSON data.
Args:
chunk_data: Parsed JSON data from Redis
Returns:
Reconstructed response object, or None if unknown type
"""
from .response_model import (
ResponseType,
StreamError,
StreamFinish,
StreamHeartbeat,
StreamStart,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
chunk_type = chunk_data.get("type")
try:
if chunk_type == ResponseType.START.value:
return StreamStart(**chunk_data)
elif chunk_type == ResponseType.FINISH.value:
return StreamFinish(**chunk_data)
elif chunk_type == ResponseType.TEXT_START.value:
return StreamTextStart(**chunk_data)
elif chunk_type == ResponseType.TEXT_DELTA.value:
return StreamTextDelta(**chunk_data)
elif chunk_type == ResponseType.TEXT_END.value:
return StreamTextEnd(**chunk_data)
elif chunk_type == ResponseType.TOOL_INPUT_START.value:
return StreamToolInputStart(**chunk_data)
elif chunk_type == ResponseType.TOOL_INPUT_AVAILABLE.value:
return StreamToolInputAvailable(**chunk_data)
elif chunk_type == ResponseType.TOOL_OUTPUT_AVAILABLE.value:
return StreamToolOutputAvailable(**chunk_data)
elif chunk_type == ResponseType.ERROR.value:
return StreamError(**chunk_data)
elif chunk_type == ResponseType.USAGE.value:
return StreamUsage(**chunk_data)
elif chunk_type == ResponseType.HEARTBEAT.value:
return StreamHeartbeat(**chunk_data)
else:
logger.warning(f"Unknown chunk type: {chunk_type}")
return None
except Exception as e:
logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}")
return None
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
"""Associate an asyncio.Task with an ActiveTask.
Args:
task_id: Task ID
asyncio_task: The asyncio Task to associate
"""
task = _active_tasks.get(task_id)
if task:
task.asyncio_task = asyncio_task
async def unsubscribe_from_task(
task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
) -> None:
"""Unsubscribe a queue from a task's stream.
Should be called when a client disconnects to clean up resources.
Also stops the pub/sub listener if there are no more local subscribers.
Args:
task_id: Task ID to unsubscribe from
subscriber_queue: The queue to remove from subscribers
"""
task = _active_tasks.get(task_id)
if task:
async with task.lock:
task.subscribers.discard(subscriber_queue)
remaining = len(task.subscribers)
logger.debug(
f"Unsubscribed from task {task_id}, "
f"remaining subscribers: {remaining}"
)
# Stop pub/sub listener if no more local subscribers
if remaining == 0:
await stop_pubsub_listener(task_id)
async def start_pubsub_listener(task_id: str) -> None:
"""Start listening to Redis pub/sub for cross-pod delivery.
This enables real-time updates when another pod publishes chunks for a task
that has local subscribers on this pod.
Args:
task_id: Task ID to listen for
"""
if task_id in _pubsub_listeners:
return # Already listening
task = _active_tasks.get(task_id)
if not task:
return
async def _listener():
try:
redis = await get_redis_async()
pubsub = redis.pubsub()
channel = _get_task_pubsub_channel(task_id)
await pubsub.subscribe(channel)
logger.debug(f"Started pub/sub listener for task {task_id}")
async for message in pubsub.listen():
if message["type"] != "message":
continue
try:
chunk_data = orjson.loads(message["data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
# Deliver to local subscribers
local_task = _active_tasks.get(task_id)
if local_task:
async with local_task.lock:
for queue in local_task.subscribers:
try:
queue.put_nowait(chunk)
except asyncio.QueueFull:
pass
# Stop listening if this was a finish event
if isinstance(chunk, StreamFinish):
break
except Exception as e:
logger.warning(f"Error processing pub/sub message: {e}")
await pubsub.unsubscribe(channel)
await pubsub.close()
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"Pub/sub listener error for task {task_id}: {e}")
finally:
_pubsub_listeners.pop(task_id, None)
logger.debug(f"Stopped pub/sub listener for task {task_id}")
listener_task = asyncio.create_task(_listener())
_pubsub_listeners[task_id] = listener_task
async def stop_pubsub_listener(task_id: str) -> None:
"""Stop the pub/sub listener for a task.
Args:
task_id: Task ID to stop listening for
"""
listener = _pubsub_listeners.pop(task_id, None)
if listener and not listener.done():
listener.cancel()
try:
await listener
except asyncio.CancelledError:
pass
logger.debug(f"Cancelled pub/sub listener for task {task_id}")

View File

@@ -57,21 +57,32 @@ async def decompose_goal(description: str, context: str = "") -> dict[str, Any]
return await decompose_goal_external(description, context)
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
async def generate_agent(
instructions: dict[str, Any],
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.
Args:
instructions: Structured instructions from decompose_goal
operation_id: Operation ID for async processing (enables RabbitMQ callback)
task_id: Task ID for async processing (enables RabbitMQ callback)
Returns:
Agent JSON dict, error dict {"type": "error", ...}, or None on error
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(instructions)
result = await generate_agent_external(instructions, operation_id, task_id)
# Don't modify async response
if result and result.get("status") == "accepted":
return result
if result:
# Check if it's an error response - pass through as-is
if isinstance(result, dict) and result.get("type") == "error":
@@ -256,7 +267,10 @@ async def get_agent_as_json(
async def generate_agent_patch(
update_request: str, current_agent: dict[str, Any]
update_request: str,
current_agent: dict[str, Any],
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Update an existing agent using natural language.
@@ -268,14 +282,18 @@ async def generate_agent_patch(
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
operation_id: Operation ID for async processing (enables RabbitMQ callback)
task_id: Task ID for async processing (enables RabbitMQ callback)
Returns:
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on unexpected error
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent_patch")
return await generate_agent_patch_external(update_request, current_agent)
return await generate_agent_patch_external(
update_request, current_agent, operation_id, task_id
)

View File

@@ -207,21 +207,42 @@ async def decompose_goal_external(
async def generate_agent_external(
instructions: dict[str, Any],
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate an agent from instructions.
Args:
instructions: Structured instructions from decompose_goal
operation_id: Operation ID for async processing (enables RabbitMQ callback)
task_id: Task ID for async processing (enables RabbitMQ callback)
Returns:
Agent JSON dict on success, or error dict {"type": "error", ...} on error
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
"""
client = _get_client()
# Build request payload
payload: dict[str, Any] = {"instructions": instructions}
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post(
"/api/generate-agent", json={"instructions": instructions}
)
response = await client.post("/api/generate-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response.raise_for_status()
data = response.json()
@@ -251,27 +272,48 @@ async def generate_agent_external(
async def generate_agent_patch_external(
update_request: str, current_agent: dict[str, Any]
update_request: str,
current_agent: dict[str, Any],
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate a patch for an existing agent.
Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
operation_id: Operation ID for async processing (enables RabbitMQ callback)
task_id: Task ID for async processing (enables RabbitMQ callback)
Returns:
Updated agent JSON, clarifying questions dict, or error dict on error
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
"""
client = _get_client()
# Build request payload
payload: dict[str, Any] = {
"update_request": update_request,
"current_agent_json": current_agent,
}
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post(
"/api/update-agent",
json={
"update_request": update_request,
"current_agent_json": current_agent,
},
)
response = await client.post("/api/update-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async update request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response.raise_for_status()
data = response.json()

View File

@@ -16,6 +16,7 @@ from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -96,6 +97,10 @@ class CreateAgentTool(BaseTool):
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not description:
return ErrorResponse(
message="Please provide a description of what the agent should do.",
@@ -192,7 +197,11 @@ class CreateAgentTool(BaseTool):
# Step 2: Generate agent JSON (external service handles fixing and validation)
try:
agent_json = await generate_agent(decomposition_result)
agent_json = await generate_agent(
decomposition_result,
operation_id=operation_id,
task_id=task_id,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
@@ -232,6 +241,19 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if agent_json.get("status") == "accepted":
logger.info(
f"Agent generation delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent generation started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
agent_name = agent_json.get("name", "Generated Agent")
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))

View File

@@ -16,6 +16,7 @@ from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -103,6 +104,10 @@ class EditAgentTool(BaseTool):
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not agent_id:
return ErrorResponse(
message="Please provide the agent ID to edit.",
@@ -134,7 +139,12 @@ class EditAgentTool(BaseTool):
# Step 2: Generate updated agent (external service handles fixing and validation)
try:
result = await generate_agent_patch(update_request, current_agent)
result = await generate_agent_patch(
update_request,
current_agent,
operation_id=operation_id,
task_id=task_id,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
@@ -153,6 +163,19 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if result.get("status") == "accepted":
logger.info(
f"Agent edit delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent edit started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
# Check if the result is an error from the external service
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")

View File

@@ -352,11 +352,15 @@ class OperationStartedResponse(ToolResponseBase):
This is returned immediately to the client while the operation continues
to execute. The user can close the tab and check back later.
The task_id can be used to reconnect to the SSE stream via
GET /chat/tasks/{task_id}/stream?last_idx=0
"""
type: ResponseType = ResponseType.OPERATION_STARTED
operation_id: str
tool_name: str
task_id: str | None = None # For SSE reconnection
class OperationPendingResponse(ToolResponseBase):
@@ -380,3 +384,20 @@ class OperationInProgressResponse(ToolResponseBase):
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
tool_call_id: str
class AsyncProcessingResponse(ToolResponseBase):
"""Response when an operation has been delegated to async processing.
This is returned by tools when the external service accepts the request
for async processing (HTTP 202 Accepted). The RabbitMQ completion consumer
will handle the result when the external service completes.
The status field is specifically "accepted" to allow the long-running tool
handler to detect this response and skip LLM continuation.
"""
type: ResponseType = ResponseType.OPERATION_STARTED
status: str = "accepted" # Must be "accepted" for detection
operation_id: str | None = None
task_id: str | None = None

View File

@@ -40,6 +40,10 @@ import backend.data.user
import backend.integrations.webhooks.utils
import backend.util.service
import backend.util.settings
from backend.api.features.chat.completion_consumer import (
start_completion_consumer,
stop_completion_consumer,
)
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
@@ -118,9 +122,21 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
# Start chat completion consumer for RabbitMQ notifications
try:
await start_completion_consumer()
except Exception as e:
logger.warning(f"Could not start chat completion consumer: {e}")
with launch_darkly_context():
yield
# Stop chat completion consumer
try:
await stop_completion_consumer()
except Exception as e:
logger.warning(f"Error stopping chat completion consumer: {e}")
try:
await shutdown_cloud_storage_handler()
except Exception as e:

View File

@@ -11,7 +11,6 @@ import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useQueryClient } from "@tanstack/react-query";
import { usePathname, useSearchParams } from "next/navigation";
import { useRef } from "react";
import { useCopilotStore } from "../../copilot-page-store";
import { useCopilotSessionId } from "../../useCopilotSessionId";
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
@@ -70,41 +69,16 @@ export function useCopilotShell() {
});
const stopStream = useChatStore((s) => s.stopStream);
const onStreamComplete = useChatStore((s) => s.onStreamComplete);
const isStreaming = useCopilotStore((s) => s.isStreaming);
const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);
const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession);
const openInterruptModal = useCopilotStore((s) => s.openInterruptModal);
const pendingActionRef = useRef<(() => void) | null>(null);
async function stopCurrentStream() {
if (!currentSessionId) return;
setIsSwitchingSession(true);
await new Promise<void>((resolve) => {
const unsubscribe = onStreamComplete((completedId) => {
if (completedId === currentSessionId) {
clearTimeout(timeout);
unsubscribe();
resolve();
}
});
const timeout = setTimeout(() => {
unsubscribe();
resolve();
}, 3000);
stopStream(currentSessionId);
});
queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(currentSessionId),
});
setIsSwitchingSession(false);
}
function selectSession(sessionId: string) {
function handleSessionClick(sessionId: string) {
if (sessionId === currentSessionId) return;
// Stop current stream - SSE reconnection allows resuming later
if (currentSessionId) {
stopStream(currentSessionId);
}
if (recentlyCreatedSessionsRef.current.has(sessionId)) {
queryClient.invalidateQueries({
queryKey: getGetV2GetSessionQueryKey(sessionId),
@@ -114,7 +88,12 @@ export function useCopilotShell() {
if (isMobile) handleCloseDrawer();
}
function startNewChat() {
function handleNewChatClick() {
// Stop current stream - SSE reconnection allows resuming later
if (currentSessionId) {
stopStream(currentSessionId);
}
resetPagination();
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
@@ -123,32 +102,6 @@ export function useCopilotShell() {
if (isMobile) handleCloseDrawer();
}
function handleSessionClick(sessionId: string) {
if (sessionId === currentSessionId) return;
if (isStreaming) {
pendingActionRef.current = async () => {
await stopCurrentStream();
selectSession(sessionId);
};
openInterruptModal(pendingActionRef.current);
} else {
selectSession(sessionId);
}
}
function handleNewChatClick() {
if (isStreaming) {
pendingActionRef.current = async () => {
await stopCurrentStream();
startNewChat();
};
openInterruptModal(pendingActionRef.current);
} else {
startNewChat();
}
}
return {
isMobile,
isDrawerOpen,

View File

@@ -4,53 +4,25 @@ import { create } from "zustand";
interface CopilotStoreState {
isStreaming: boolean;
isSwitchingSession: boolean;
isCreatingSession: boolean;
isInterruptModalOpen: boolean;
pendingAction: (() => void) | null;
}
interface CopilotStoreActions {
setIsStreaming: (isStreaming: boolean) => void;
setIsSwitchingSession: (isSwitchingSession: boolean) => void;
setIsCreatingSession: (isCreating: boolean) => void;
openInterruptModal: (onConfirm: () => void) => void;
confirmInterrupt: () => void;
cancelInterrupt: () => void;
}
type CopilotStore = CopilotStoreState & CopilotStoreActions;
export const useCopilotStore = create<CopilotStore>((set, get) => ({
export const useCopilotStore = create<CopilotStore>((set) => ({
isStreaming: false,
isSwitchingSession: false,
isCreatingSession: false,
isInterruptModalOpen: false,
pendingAction: null,
setIsStreaming(isStreaming) {
set({ isStreaming });
},
setIsSwitchingSession(isSwitchingSession) {
set({ isSwitchingSession });
},
setIsCreatingSession(isCreatingSession) {
set({ isCreatingSession });
},
openInterruptModal(onConfirm) {
set({ isInterruptModalOpen: true, pendingAction: onConfirm });
},
confirmInterrupt() {
const { pendingAction } = get();
set({ isInterruptModalOpen: false, pendingAction: null });
if (pendingAction) pendingAction();
},
cancelInterrupt() {
set({ isInterruptModalOpen: false, pendingAction: null });
},
}));

View File

@@ -5,15 +5,10 @@ import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
import { Text } from "@/components/atoms/Text/Text";
import { Chat } from "@/components/contextual/Chat/Chat";
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
import { Dialog } from "@/components/molecules/Dialog/Dialog";
import { useCopilotStore } from "./copilot-page-store";
import { useCopilotPage } from "./useCopilotPage";
export default function CopilotPage() {
const { state, handlers } = useCopilotPage();
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
const {
greetingName,
quickActions,
@@ -40,42 +35,6 @@ export default function CopilotPage() {
onSessionNotFound={handleSessionNotFound}
onStreamingChange={handleStreamingChange}
/>
<Dialog
title="Interrupt current chat?"
styling={{ maxWidth: 300, width: "100%" }}
controlled={{
isOpen: isInterruptModalOpen,
set: (open) => {
if (!open) cancelInterrupt();
},
}}
onClose={cancelInterrupt}
>
<Dialog.Content>
<div className="flex flex-col gap-4">
<Text variant="body">
The current chat response will be interrupted. Are you sure you
want to continue?
</Text>
<Dialog.Footer>
<Button
type="button"
variant="outline"
onClick={cancelInterrupt}
>
Cancel
</Button>
<Button
type="button"
variant="primary"
onClick={confirmInterrupt}
>
Continue
</Button>
</Dialog.Footer>
</div>
</Dialog.Content>
</Dialog>
</div>
);
}

View File

@@ -5,7 +5,6 @@ import {
import { useToast } from "@/components/molecules/Toast/use-toast";
import { getHomepageRoute } from "@/lib/constants";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
import {
Flag,
type FlagValues,
@@ -26,20 +25,12 @@ export function useCopilotPage() {
const queryClient = useQueryClient();
const { user, isLoggedIn, isUserLoading } = useSupabase();
const { toast } = useToast();
const { completeStep } = useOnboarding();
const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
const setIsStreaming = useCopilotStore((s) => s.setIsStreaming);
const isCreating = useCopilotStore((s) => s.isCreatingSession);
const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession);
// Complete VISIT_COPILOT onboarding step to grant $5 welcome bonus
useEffect(() => {
if (isLoggedIn) {
completeStep("VISIT_COPILOT");
}
}, [completeStep, isLoggedIn]);
const isChatEnabled = useGetFlag(Flag.CHAT);
const flags = useFlags<FlagValues>();
const homepageRoute = getHomepageRoute(isChatEnabled);

View File

@@ -0,0 +1,81 @@
import { environment } from "@/services/environment";
import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
import { NextRequest } from "next/server";
/**
* SSE Proxy for task stream reconnection.
*
* This endpoint allows clients to reconnect to an ongoing or recently completed
* background task's stream. It replays missed messages from Redis Streams and
* subscribes to live updates if the task is still running.
*
* Client contract:
* 1. When receiving an operation_started event, store the task_id
* 2. To reconnect: GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
* 3. Messages are replayed from the last_message_id position
* 4. Stream ends when "finish" event is received
*/
export async function GET(
request: NextRequest,
{ params }: { params: Promise<{ taskId: string }> },
) {
const { taskId } = await params;
const searchParams = request.nextUrl.searchParams;
const lastMessageId = searchParams.get("last_message_id") || "0-0";
try {
// Get auth token from server-side session
const token = await getServerAuthToken();
// Build backend URL
const backendUrl = environment.getAGPTServerBaseUrl();
const streamUrl = new URL(`/api/chat/tasks/${taskId}/stream`, backendUrl);
streamUrl.searchParams.set("last_message_id", lastMessageId);
// Forward request to backend with auth header
const headers: Record<string, string> = {
Accept: "text/event-stream",
"Cache-Control": "no-cache",
Connection: "keep-alive",
};
if (token) {
headers["Authorization"] = `Bearer ${token}`;
}
const response = await fetch(streamUrl.toString(), {
method: "GET",
headers,
});
if (!response.ok) {
const error = await response.text();
return new Response(error, {
status: response.status,
headers: { "Content-Type": "application/json" },
});
}
// Return the SSE stream directly
return new Response(response.body, {
headers: {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache, no-transform",
Connection: "keep-alive",
"X-Accel-Buffering": "no",
},
});
} catch (error) {
console.error("Task stream proxy error:", error);
return new Response(
JSON.stringify({
error: "Failed to connect to task stream",
detail: error instanceof Error ? error.message : String(error),
}),
{
status: 500,
headers: { "Content-Type": "application/json" },
},
);
}
}

View File

@@ -939,6 +939,63 @@
}
}
},
"/api/chat/operations/{operation_id}/complete": {
"post": {
"tags": ["v2", "chat", "chat"],
"summary": "Complete Operation",
"description": "External completion webhook for long-running operations.\n\nCalled by Agent Generator (or other services) when an operation completes.\nThis triggers the stream registry to publish completion and continue LLM generation.\n\nArgs:\n operation_id: The operation ID to complete.\n request: Completion payload with success status and result/error.\n x_api_key: Internal API key for authentication.\n\nReturns:\n dict: Status of the completion.\n\nRaises:\n HTTPException: If API key is invalid or operation not found.",
"operationId": "postV2CompleteOperation",
"parameters": [
{
"name": "operation_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Operation Id" }
},
{
"name": "x-api-key",
"in": "header",
"required": false,
"schema": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "X-Api-Key"
}
}
],
"requestBody": {
"required": true,
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/OperationCompleteRequest"
}
}
}
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"type": "object",
"additionalProperties": true,
"title": "Response Postv2Completeoperation"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/chat/sessions": {
"get": {
"tags": ["v2", "chat", "chat"],
@@ -1195,6 +1252,94 @@
}
}
},
"/api/chat/tasks/{task_id}": {
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Get Task Status",
"description": "Get the status of a long-running task.\n\nArgs:\n task_id: The task ID to check.\n user_id: Authenticated user ID for ownership validation.\n\nReturns:\n dict: Task status including task_id, status, tool_name, and operation_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.",
"operationId": "getV2GetTaskStatus",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "task_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Task Id" }
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"type": "object",
"additionalProperties": true,
"title": "Response Getv2Gettaskstatus"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/chat/tasks/{task_id}/stream": {
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Stream Task",
"description": "Reconnect to a long-running task's SSE stream.\n\nWhen a long-running operation (like agent generation) starts, the client\nreceives a task_id. If the connection drops, the client can reconnect\nusing this endpoint to resume receiving updates.\n\nArgs:\n task_id: The task ID from the operation_started response.\n user_id: Authenticated user ID for ownership validation.\n last_message_id: Last Redis Stream message ID received (\"0-0\" for full replay).\n\nReturns:\n StreamingResponse: SSE-formatted response chunks starting after last_message_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.",
"operationId": "getV2StreamTask",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "task_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Task Id" }
},
{
"name": "last_message_id",
"in": "query",
"required": false,
"schema": {
"type": "string",
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
"default": "0-0",
"title": "Last Message Id"
},
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay."
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": { "application/json": { "schema": {} } }
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/credits": {
"get": {
"tags": ["v1", "credits"],
@@ -8804,6 +8949,27 @@
],
"title": "OnboardingStep"
},
"OperationCompleteRequest": {
"properties": {
"success": { "type": "boolean", "title": "Success" },
"result": {
"anyOf": [
{ "additionalProperties": true, "type": "object" },
{ "type": "string" },
{ "type": "null" }
],
"title": "Result"
},
"error": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Error"
}
},
"type": "object",
"required": ["success"],
"title": "OperationCompleteRequest",
"description": "Request model for external completion webhook."
},
"Pagination": {
"properties": {
"total_items": {

View File

@@ -1,7 +1,6 @@
"use client";
import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId";
import { useCopilotStore } from "@/app/(platform)/copilot/copilot-page-store";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
@@ -25,7 +24,6 @@ export function Chat({
}: ChatProps) {
const { urlSessionId } = useCopilotSessionId();
const hasHandledNotFoundRef = useRef(false);
const isSwitchingSession = useCopilotStore((s) => s.isSwitchingSession);
const {
messages,
isLoading,
@@ -53,8 +51,7 @@ export function Chat({
isCreating,
]);
const shouldShowLoader =
(showLoader && (isLoading || isCreating)) || isSwitchingSession;
const shouldShowLoader = showLoader && (isLoading || isCreating);
return (
<div className={cn("flex h-full flex-col", className)}>
@@ -66,21 +63,19 @@ export function Chat({
<div className="flex flex-col items-center gap-3">
<LoadingSpinner size="large" className="text-neutral-400" />
<Text variant="body" className="text-zinc-500">
{isSwitchingSession
? "Switching chat..."
: "Loading your chat..."}
Loading your chat...
</Text>
</div>
</div>
)}
{/* Error State */}
{error && !isLoading && !isSwitchingSession && (
{error && !isLoading && (
<ChatErrorState error={error} onRetry={createSession} />
)}
{/* Session Content */}
{sessionId && !isLoading && !error && !isSwitchingSession && (
{sessionId && !isLoading && !error && (
<ChatContainer
sessionId={sessionId}
initialMessages={messages}

View File

@@ -0,0 +1,156 @@
# SSE Reconnection Contract for Long-Running Operations
This document describes the client-side contract for handling SSE (Server-Sent Events) disconnections and reconnecting to long-running background tasks.
## Overview
When a user triggers a long-running operation (like agent generation), the backend:
1. Spawns a background task that survives SSE disconnections
2. Returns an `operation_started` response with a `task_id`
3. Stores stream messages in Redis Streams for replay
Clients can reconnect to the task stream at any time to receive missed messages.
## Client-Side Flow
### 1. Receiving Operation Started
When you receive an `operation_started` tool response:
```typescript
// The response includes a task_id for reconnection
{
type: "operation_started",
tool_name: "generate_agent",
operation_id: "uuid-...",
task_id: "task-uuid-...", // <-- Store this for reconnection
message: "Operation started. You can close this tab."
}
```
### 2. Storing Task Info
Use the chat store to track the active task:
```typescript
import { useChatStore } from "./chat-store";
// When operation_started is received:
useChatStore.getState().setActiveTask(sessionId, {
taskId: response.task_id,
operationId: response.operation_id,
toolName: response.tool_name,
lastMessageId: "0",
});
```
### 3. Reconnecting to a Task
To reconnect (e.g., after page refresh or tab reopen):
```typescript
const { reconnectToTask, getActiveTask } = useChatStore.getState();
// Check if there's an active task for this session
const activeTask = getActiveTask(sessionId);
if (activeTask) {
// Reconnect to the task stream
await reconnectToTask(
sessionId,
activeTask.taskId,
activeTask.lastMessageId, // Resume from last position
(chunk) => {
// Handle incoming chunks
console.log("Received chunk:", chunk);
}
);
}
```
### 4. Tracking Message Position
To enable precise replay, update the last message ID as chunks arrive:
```typescript
const { updateTaskLastMessageId } = useChatStore.getState();
function handleChunk(chunk: StreamChunk) {
// If chunk has an index/id, track it
if (chunk.idx !== undefined) {
updateTaskLastMessageId(sessionId, String(chunk.idx));
}
}
```
## API Endpoints
### Task Stream Reconnection
```
GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
```
- `taskId`: The task ID from `operation_started`
- `last_message_id`: Last received message index (default: "0" for full replay)
Returns: SSE stream of missed messages + live updates
## Chunk Types
The reconnected stream follows the same Vercel AI SDK protocol:
| Type | Description |
|------|-------------|
| `start` | Message lifecycle start |
| `text-delta` | Streaming text content |
| `text-end` | Text block completed |
| `tool-output-available` | Tool result available |
| `finish` | Stream completed |
| `error` | Error occurred |
## Error Handling
If reconnection fails:
1. Check if task still exists (may have expired - default TTL: 1 hour)
2. Fall back to polling the session for final state
3. Show appropriate UI message to user
## Persistence Considerations
For robust reconnection across browser restarts:
```typescript
// Store in localStorage/sessionStorage
const ACTIVE_TASKS_KEY = "chat_active_tasks";
function persistActiveTask(sessionId: string, task: ActiveTaskInfo) {
const tasks = JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
tasks[sessionId] = task;
localStorage.setItem(ACTIVE_TASKS_KEY, JSON.stringify(tasks));
}
function loadPersistedTasks(): Record<string, ActiveTaskInfo> {
return JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
}
```
## Backend Configuration
The following backend settings affect reconnection behavior:
| Setting | Default | Description |
|---------|---------|-------------|
| `stream_ttl` | 3600s | How long streams are kept in Redis |
| `stream_max_length` | 1000 | Max messages per stream |
## Testing
To test reconnection locally:
1. Start a long-running operation (e.g., agent generation)
2. Note the `task_id` from the `operation_started` response
3. Close the browser tab
4. Reopen and call `reconnectToTask` with the saved `task_id`
5. Verify that missed messages are replayed
See the main README for full local development setup.

View File

@@ -8,15 +8,68 @@ import type {
StreamResult,
StreamStatus,
} from "./chat-types";
import { executeStream } from "./stream-executor";
import { executeStream, executeTaskReconnect } from "./stream-executor";
const COMPLETED_STREAM_TTL = 5 * 60 * 1000; // 5 minutes
const ACTIVE_TASKS_STORAGE_KEY = "chat_active_tasks";
const TASK_TTL = 60 * 60 * 1000; // 1 hour - tasks expire after this
/**
* Tracks active task info for SSE reconnection.
* When a long-running operation starts, we store this so clients can reconnect
* if the browser tab is closed and reopened.
*/
export interface ActiveTaskInfo {
taskId: string;
sessionId: string;
operationId: string;
toolName: string;
lastMessageId: string; // Last processed message ID for replay (Redis Stream format: "0-0")
startedAt: number;
}
/** Load active tasks from localStorage */
function loadPersistedTasks(): Map<string, ActiveTaskInfo> {
if (typeof window === "undefined") return new Map();
try {
const stored = localStorage.getItem(ACTIVE_TASKS_STORAGE_KEY);
if (!stored) return new Map();
const parsed = JSON.parse(stored) as Record<string, ActiveTaskInfo>;
const now = Date.now();
const tasks = new Map<string, ActiveTaskInfo>();
// Filter out expired tasks
for (const [sessionId, task] of Object.entries(parsed)) {
if (now - task.startedAt < TASK_TTL) {
tasks.set(sessionId, task);
}
}
return tasks;
} catch {
return new Map();
}
}
/** Save active tasks to localStorage */
function persistTasks(tasks: Map<string, ActiveTaskInfo>): void {
if (typeof window === "undefined") return;
try {
const obj: Record<string, ActiveTaskInfo> = {};
for (const [sessionId, task] of tasks) {
obj[sessionId] = task;
}
localStorage.setItem(ACTIVE_TASKS_STORAGE_KEY, JSON.stringify(obj));
} catch {
// Ignore storage errors
}
}
interface ChatStoreState {
activeStreams: Map<string, ActiveStream>;
completedStreams: Map<string, StreamResult>;
activeSessions: Set<string>;
streamCompleteCallbacks: Set<StreamCompleteCallback>;
/** Active tasks for SSE reconnection - keyed by sessionId */
activeTasks: Map<string, ActiveTaskInfo>;
}
interface ChatStoreActions {
@@ -41,6 +94,24 @@ interface ChatStoreActions {
unregisterActiveSession: (sessionId: string) => void;
isSessionActive: (sessionId: string) => boolean;
onStreamComplete: (callback: StreamCompleteCallback) => () => void;
/** Track active task for SSE reconnection */
setActiveTask: (
sessionId: string,
taskInfo: Omit<ActiveTaskInfo, "sessionId" | "startedAt">,
) => void;
/** Get active task for a session */
getActiveTask: (sessionId: string) => ActiveTaskInfo | undefined;
/** Clear active task when operation completes */
clearActiveTask: (sessionId: string) => void;
/** Reconnect to an existing task stream */
reconnectToTask: (
sessionId: string,
taskId: string,
lastMessageId?: string,
onChunk?: (chunk: StreamChunk) => void,
) => Promise<void>;
/** Update last message ID for a task (for tracking replay position) */
updateTaskLastMessageId: (sessionId: string, lastMessageId: string) => void;
}
type ChatStore = ChatStoreState & ChatStoreActions;
@@ -76,6 +147,7 @@ export const useChatStore = create<ChatStore>((set, get) => ({
completedStreams: new Map(),
activeSessions: new Set(),
streamCompleteCallbacks: new Set(),
activeTasks: loadPersistedTasks(),
startStream: async function startStream(
sessionId,
@@ -286,4 +358,139 @@ export const useChatStore = create<ChatStore>((set, get) => ({
set({ streamCompleteCallbacks: cleanedCallbacks });
};
},
setActiveTask: function setActiveTask(sessionId, taskInfo) {
const state = get();
const newActiveTasks = new Map(state.activeTasks);
newActiveTasks.set(sessionId, {
...taskInfo,
sessionId,
startedAt: Date.now(),
});
set({ activeTasks: newActiveTasks });
persistTasks(newActiveTasks);
},
getActiveTask: function getActiveTask(sessionId) {
return get().activeTasks.get(sessionId);
},
clearActiveTask: function clearActiveTask(sessionId) {
const state = get();
if (!state.activeTasks.has(sessionId)) return;
const newActiveTasks = new Map(state.activeTasks);
newActiveTasks.delete(sessionId);
set({ activeTasks: newActiveTasks });
persistTasks(newActiveTasks);
},
reconnectToTask: async function reconnectToTask(
sessionId,
taskId,
lastMessageId = "0-0", // Redis Stream ID format
onChunk,
) {
const state = get();
const newActiveStreams = new Map(state.activeStreams);
let newCompletedStreams = new Map(state.completedStreams);
const callbacks = state.streamCompleteCallbacks;
// Clean up any existing stream for this session
const existingStream = newActiveStreams.get(sessionId);
if (existingStream) {
existingStream.abortController.abort();
const normalizedStatus =
existingStream.status === "streaming"
? "completed"
: existingStream.status;
const result: StreamResult = {
sessionId,
status: normalizedStatus,
chunks: existingStream.chunks,
completedAt: Date.now(),
error: existingStream.error,
};
newCompletedStreams.set(sessionId, result);
newActiveStreams.delete(sessionId);
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
}
const abortController = new AbortController();
const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
if (onChunk) initialCallbacks.add(onChunk);
const stream: ActiveStream = {
sessionId,
abortController,
status: "streaming",
startedAt: Date.now(),
chunks: [],
onChunkCallbacks: initialCallbacks,
};
newActiveStreams.set(sessionId, stream);
set({
activeStreams: newActiveStreams,
completedStreams: newCompletedStreams,
});
try {
await executeTaskReconnect(stream, taskId, lastMessageId);
} finally {
if (onChunk) stream.onChunkCallbacks.delete(onChunk);
if (stream.status !== "streaming") {
const currentState = get();
const finalActiveStreams = new Map(currentState.activeStreams);
let finalCompletedStreams = new Map(currentState.completedStreams);
const storedStream = finalActiveStreams.get(sessionId);
if (storedStream === stream) {
const result: StreamResult = {
sessionId,
status: stream.status,
chunks: stream.chunks,
completedAt: Date.now(),
error: stream.error,
};
finalCompletedStreams.set(sessionId, result);
finalActiveStreams.delete(sessionId);
finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
set({
activeStreams: finalActiveStreams,
completedStreams: finalCompletedStreams,
});
if (stream.status === "completed" || stream.status === "error") {
notifyStreamComplete(
currentState.streamCompleteCallbacks,
sessionId,
);
// Clear active task on completion
const taskState = get();
const newActiveTasks = new Map(taskState.activeTasks);
newActiveTasks.delete(sessionId);
set({ activeTasks: newActiveTasks });
persistTasks(newActiveTasks);
}
}
}
}
},
updateTaskLastMessageId: function updateTaskLastMessageId(
sessionId,
lastMessageId,
) {
const state = get();
const task = state.activeTasks.get(sessionId);
if (!task) return;
const newActiveTasks = new Map(state.activeTasks);
newActiveTasks.set(sessionId, {
...task,
lastMessageId,
});
set({ activeTasks: newActiveTasks });
persistTasks(newActiveTasks);
},
}));

View File

@@ -23,6 +23,12 @@ export interface HandlerDependencies {
setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
sessionId: string;
onOperationStarted?: () => void;
onActiveTaskStarted?: (taskInfo: {
taskId: string;
operationId: string;
toolName: string;
toolCallId: string;
}) => void;
}
export function isRegionBlockedError(chunk: StreamChunk): boolean {
@@ -164,9 +170,19 @@ export function handleToolResponse(
}
return;
}
// Trigger polling when operation_started is received
// Trigger polling and store task info when operation_started is received
if (responseMessage.type === "operation_started") {
deps.onOperationStarted?.();
// Store task info for SSE reconnection if taskId is present
const taskId = (responseMessage as any).taskId;
if (taskId && deps.onActiveTaskStarted) {
deps.onActiveTaskStarted({
taskId,
operationId: (responseMessage as any).operationId || "",
toolName: (responseMessage as any).toolName || "",
toolCallId: (responseMessage as any).toolId || "",
});
}
}
deps.setMessages((prev) => {

View File

@@ -349,6 +349,7 @@ export function parseToolResponse(
toolName: (parsedResult.tool_name as string) || toolName,
toolId,
operationId: (parsedResult.operation_id as string) || "",
taskId: (parsedResult.task_id as string) || undefined, // For SSE reconnection
message:
(parsedResult.message as string) ||
"Operation started. You can close this tab.",

View File

@@ -65,8 +65,27 @@ export function useChatContainer({
} = useChatStream();
const activeStreams = useChatStore((s) => s.activeStreams);
const subscribeToStream = useChatStore((s) => s.subscribeToStream);
const setActiveTask = useChatStore((s) => s.setActiveTask);
const getActiveTask = useChatStore((s) => s.getActiveTask);
const reconnectToTask = useChatStore((s) => s.reconnectToTask);
const isStreaming = isStreamingInitiated || hasTextChunks;
// Callback to store active task info for SSE reconnection
function handleActiveTaskStarted(taskInfo: {
taskId: string;
operationId: string;
toolName: string;
toolCallId: string;
}) {
if (!sessionId) return;
setActiveTask(sessionId, {
taskId: taskInfo.taskId,
operationId: taskInfo.operationId,
toolName: taskInfo.toolName,
lastMessageId: "0-0", // Redis Stream ID format for full replay
});
}
useEffect(
function handleSessionChange() {
if (sessionId === previousSessionIdRef.current) return;
@@ -85,6 +104,34 @@ export function useChatContainer({
if (!sessionId) return;
// Check if there's an active task for this session that we should reconnect to
const activeTask = getActiveTask(sessionId);
if (activeTask) {
const dispatcher = createStreamEventDispatcher({
setHasTextChunks,
setStreamingChunks,
streamingChunksRef,
hasResponseRef,
setMessages,
setIsRegionBlockedModalOpen,
sessionId,
setIsStreamingInitiated,
onOperationStarted,
onActiveTaskStarted: handleActiveTaskStarted,
});
setIsStreamingInitiated(true);
// Reconnect to the task stream
reconnectToTask(
sessionId,
activeTask.taskId,
activeTask.lastMessageId,
dispatcher,
);
return;
}
// Otherwise check for an in-memory active stream
const activeStream = activeStreams.get(sessionId);
if (!activeStream || activeStream.status !== "streaming") return;
@@ -98,6 +145,7 @@ export function useChatContainer({
sessionId,
setIsStreamingInitiated,
onOperationStarted,
onActiveTaskStarted: handleActiveTaskStarted,
});
setIsStreamingInitiated(true);
@@ -110,6 +158,8 @@ export function useChatContainer({
activeStreams,
subscribeToStream,
onOperationStarted,
getActiveTask,
reconnectToTask,
],
);
@@ -225,6 +275,7 @@ export function useChatContainer({
sessionId,
setIsStreamingInitiated,
onOperationStarted,
onActiveTaskStarted: handleActiveTaskStarted,
});
try {

View File

@@ -111,6 +111,7 @@ export type ChatMessageData =
toolName: string;
toolId: string;
operationId: string;
taskId?: string; // For SSE reconnection
message: string;
timestamp?: string | Date;
}

View File

@@ -10,8 +10,14 @@ import {
parseSSELine,
} from "./stream-utils";
function notifySubscribers(stream: ActiveStream, chunk: StreamChunk) {
stream.chunks.push(chunk);
function notifySubscribers(
stream: ActiveStream,
chunk: StreamChunk,
skipStore = false,
) {
if (!skipStore) {
stream.chunks.push(chunk);
}
for (const callback of stream.onChunkCallbacks) {
try {
callback(chunk);
@@ -140,3 +146,124 @@ export async function executeStream(
});
}
}
/**
* Reconnect to an existing task stream.
*
* This is used when a client wants to resume receiving updates from a
* long-running background task. Messages are replayed from the last_message_id
* position, allowing clients to catch up on missed events.
*
* @param stream - The active stream state
* @param taskId - The task ID to reconnect to
* @param lastMessageId - The last message ID received (for replay)
* @param retryCount - Current retry count
*/
export async function executeTaskReconnect(
stream: ActiveStream,
taskId: string,
lastMessageId: string = "0",
retryCount: number = 0,
): Promise<void> {
const { abortController } = stream;
try {
const url = `/api/chat/tasks/${taskId}/stream?last_message_id=${encodeURIComponent(lastMessageId)}`;
const response = await fetch(url, {
method: "GET",
headers: {
Accept: "text/event-stream",
},
signal: abortController.signal,
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(errorText || `HTTP ${response.status}`);
}
if (!response.body) {
throw new Error("Response body is null");
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
let buffer = "";
while (true) {
const { done, value } = await reader.read();
if (done) {
notifySubscribers(stream, { type: "stream_end" });
stream.status = "completed";
return;
}
buffer += decoder.decode(value, { stream: true });
const lines = buffer.split("\n");
buffer = lines.pop() || "";
for (const line of lines) {
const data = parseSSELine(line);
if (data !== null) {
if (data === "[DONE]") {
notifySubscribers(stream, { type: "stream_end" });
stream.status = "completed";
return;
}
try {
const rawChunk = JSON.parse(data) as
| StreamChunk
| VercelStreamChunk;
const chunk = normalizeStreamChunk(rawChunk);
if (!chunk) continue;
notifySubscribers(stream, chunk);
if (chunk.type === "stream_end") {
stream.status = "completed";
return;
}
if (chunk.type === "error") {
stream.status = "error";
stream.error = new Error(
chunk.message || chunk.content || "Stream error",
);
return;
}
} catch (err) {
console.warn(
"[StreamExecutor] Failed to parse task reconnect SSE chunk:",
err,
);
}
}
}
}
} catch (err) {
if (err instanceof Error && err.name === "AbortError") {
notifySubscribers(stream, { type: "stream_end" });
stream.status = "completed";
return;
}
if (retryCount < MAX_RETRIES) {
const retryDelay = INITIAL_RETRY_DELAY * Math.pow(2, retryCount);
console.log(
`[StreamExecutor] Task reconnect retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`,
);
await new Promise((resolve) => setTimeout(resolve, retryDelay));
return executeTaskReconnect(stream, taskId, lastMessageId, retryCount + 1);
}
stream.status = "error";
stream.error = err instanceof Error ? err : new Error("Task reconnect failed");
notifySubscribers(stream, {
type: "error",
message: stream.error.message,
});
}
}