mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-30 09:28:19 -05:00
Compare commits
10 Commits
feat/sub-a
...
swiftyos/s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d1da7fe5da | ||
|
|
11e27cfdcf | ||
|
|
0be5fedc86 | ||
|
|
f2e81648b5 | ||
|
|
bb608ea60d | ||
|
|
46af3b94f2 | ||
|
|
083cceca0f | ||
|
|
06758adefd | ||
|
|
c01c29a059 | ||
|
|
d738059da8 |
@@ -0,0 +1,325 @@
|
||||
"""RabbitMQ consumer for operation completion messages.
|
||||
|
||||
This module provides a consumer that listens for completion notifications
|
||||
from external services (like Agent Generator) and triggers the appropriate
|
||||
stream registry and chat service updates.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import orjson
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.rabbitmq import (
|
||||
AsyncRabbitMQ,
|
||||
Exchange,
|
||||
ExchangeType,
|
||||
Queue,
|
||||
RabbitMQConfig,
|
||||
)
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .response_model import StreamError, StreamFinish, StreamToolOutputAvailable
|
||||
from .tools.models import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Queue and exchange configuration
|
||||
OPERATION_COMPLETE_EXCHANGE = Exchange(
|
||||
name="chat_operations",
|
||||
type=ExchangeType.DIRECT,
|
||||
durable=True,
|
||||
)
|
||||
|
||||
OPERATION_COMPLETE_QUEUE = Queue(
|
||||
name="chat_operation_complete",
|
||||
durable=True,
|
||||
exchange=OPERATION_COMPLETE_EXCHANGE,
|
||||
routing_key="operation.complete",
|
||||
)
|
||||
|
||||
RABBITMQ_CONFIG = RabbitMQConfig(
|
||||
exchanges=[OPERATION_COMPLETE_EXCHANGE],
|
||||
queues=[OPERATION_COMPLETE_QUEUE],
|
||||
)
|
||||
|
||||
|
||||
class OperationCompleteMessage(BaseModel):
|
||||
"""Message format for operation completion notifications."""
|
||||
|
||||
operation_id: str
|
||||
task_id: str
|
||||
success: bool
|
||||
result: dict | str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class ChatCompletionConsumer:
|
||||
"""Consumer for chat operation completion messages from RabbitMQ."""
|
||||
|
||||
def __init__(self):
|
||||
self._rabbitmq: AsyncRabbitMQ | None = None
|
||||
self._consumer_task: asyncio.Task | None = None
|
||||
self._running = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the completion consumer."""
|
||||
if self._running:
|
||||
logger.warning("Completion consumer already running")
|
||||
return
|
||||
|
||||
self._rabbitmq = AsyncRabbitMQ(RABBITMQ_CONFIG)
|
||||
await self._rabbitmq.connect()
|
||||
|
||||
self._running = True
|
||||
self._consumer_task = asyncio.create_task(self._consume_messages())
|
||||
logger.info("Chat completion consumer started")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the completion consumer."""
|
||||
self._running = False
|
||||
|
||||
if self._consumer_task:
|
||||
self._consumer_task.cancel()
|
||||
try:
|
||||
await self._consumer_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._consumer_task = None
|
||||
|
||||
if self._rabbitmq:
|
||||
await self._rabbitmq.disconnect()
|
||||
self._rabbitmq = None
|
||||
|
||||
logger.info("Chat completion consumer stopped")
|
||||
|
||||
async def _consume_messages(self) -> None:
|
||||
"""Main message consumption loop with retry logic."""
|
||||
max_retries = 10
|
||||
retry_delay = 5 # seconds
|
||||
retry_count = 0
|
||||
|
||||
while self._running and retry_count < max_retries:
|
||||
if not self._rabbitmq:
|
||||
logger.error("RabbitMQ not initialized")
|
||||
return
|
||||
|
||||
try:
|
||||
channel = await self._rabbitmq.get_channel()
|
||||
queue = await channel.get_queue(OPERATION_COMPLETE_QUEUE.name)
|
||||
|
||||
# Reset retry count on successful connection
|
||||
retry_count = 0
|
||||
|
||||
async with queue.iterator() as queue_iter:
|
||||
async for message in queue_iter:
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
try:
|
||||
async with message.process():
|
||||
await self._handle_message(message.body)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing completion message: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Message will be requeued due to exception
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Consumer cancelled")
|
||||
return
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
logger.error(
|
||||
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
if self._running and retry_count < max_retries:
|
||||
await asyncio.sleep(retry_delay)
|
||||
else:
|
||||
logger.error("Max retries reached, stopping consumer")
|
||||
return
|
||||
|
||||
async def _handle_message(self, body: bytes) -> None:
|
||||
"""Handle a single completion message."""
|
||||
try:
|
||||
data = orjson.loads(body)
|
||||
message = OperationCompleteMessage(**data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse completion message: {e}")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Received completion for operation {message.operation_id} "
|
||||
f"(task_id={message.task_id}, success={message.success})"
|
||||
)
|
||||
|
||||
# Find task in registry
|
||||
task = await stream_registry.find_task_by_operation_id(message.operation_id)
|
||||
if task is None:
|
||||
# Try to look up by task_id directly
|
||||
task = await stream_registry.get_task(message.task_id)
|
||||
|
||||
if task is None:
|
||||
logger.warning(
|
||||
f"Task not found for operation {message.operation_id} "
|
||||
f"(task_id={message.task_id})"
|
||||
)
|
||||
return
|
||||
|
||||
if message.success:
|
||||
await self._handle_success(task, message)
|
||||
else:
|
||||
await self._handle_failure(task, message)
|
||||
|
||||
async def _handle_success(
|
||||
self,
|
||||
task: stream_registry.ActiveTask,
|
||||
message: OperationCompleteMessage,
|
||||
) -> None:
|
||||
"""Handle successful operation completion."""
|
||||
# Publish result to stream registry
|
||||
result_output = message.result if message.result else {"status": "completed"}
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=task.tool_call_id,
|
||||
toolName=task.tool_name,
|
||||
output=(
|
||||
result_output
|
||||
if isinstance(result_output, str)
|
||||
else orjson.dumps(result_output).decode("utf-8")
|
||||
),
|
||||
success=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Update pending operation in database
|
||||
result_str = (
|
||||
message.result
|
||||
if isinstance(message.result, str)
|
||||
else (
|
||||
orjson.dumps(message.result).decode("utf-8")
|
||||
if message.result
|
||||
else '{"status": "completed"}'
|
||||
)
|
||||
)
|
||||
await chat_service._update_pending_operation(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
result=result_str,
|
||||
)
|
||||
|
||||
# Generate LLM continuation with streaming
|
||||
await chat_service._generate_llm_continuation_with_streaming(
|
||||
session_id=task.session_id,
|
||||
user_id=task.user_id,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
|
||||
# Mark task as completed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="completed")
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
|
||||
logger.info(
|
||||
f"Successfully processed completion for task {task.task_id} "
|
||||
f"(operation {message.operation_id})"
|
||||
)
|
||||
|
||||
async def _handle_failure(
|
||||
self,
|
||||
task: stream_registry.ActiveTask,
|
||||
message: OperationCompleteMessage,
|
||||
) -> None:
|
||||
"""Handle failed operation completion."""
|
||||
error_msg = message.error or "Operation failed"
|
||||
|
||||
# Publish error to stream registry followed by finish event
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText=error_msg),
|
||||
)
|
||||
await stream_registry.publish_chunk(task.task_id, StreamFinish())
|
||||
|
||||
# Update pending operation with error
|
||||
error_response = ErrorResponse(
|
||||
message=error_msg,
|
||||
error=message.error,
|
||||
)
|
||||
await chat_service._update_pending_operation(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
result=error_response.model_dump_json(),
|
||||
)
|
||||
|
||||
# Mark task as failed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
|
||||
logger.info(
|
||||
f"Processed failure for task {task.task_id} "
|
||||
f"(operation {message.operation_id}): {error_msg}"
|
||||
)
|
||||
|
||||
|
||||
# Module-level consumer instance
|
||||
_consumer: ChatCompletionConsumer | None = None
|
||||
|
||||
|
||||
async def start_completion_consumer() -> None:
|
||||
"""Start the global completion consumer."""
|
||||
global _consumer
|
||||
if _consumer is None:
|
||||
_consumer = ChatCompletionConsumer()
|
||||
await _consumer.start()
|
||||
|
||||
|
||||
async def stop_completion_consumer() -> None:
|
||||
"""Stop the global completion consumer."""
|
||||
global _consumer
|
||||
if _consumer:
|
||||
await _consumer.stop()
|
||||
_consumer = None
|
||||
|
||||
|
||||
async def publish_operation_complete(
|
||||
operation_id: str,
|
||||
task_id: str,
|
||||
success: bool,
|
||||
result: dict | str | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
"""Publish an operation completion message.
|
||||
|
||||
This is a helper function for testing or for services that want to
|
||||
publish completion messages directly.
|
||||
|
||||
Args:
|
||||
operation_id: The operation ID that completed.
|
||||
task_id: The task ID associated with the operation.
|
||||
success: Whether the operation succeeded.
|
||||
result: The result data (for success).
|
||||
error: The error message (for failure).
|
||||
"""
|
||||
message = OperationCompleteMessage(
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
success=success,
|
||||
result=result,
|
||||
error=error,
|
||||
)
|
||||
|
||||
rabbitmq = AsyncRabbitMQ(RABBITMQ_CONFIG)
|
||||
try:
|
||||
await rabbitmq.connect()
|
||||
await rabbitmq.publish_message(
|
||||
routing_key="operation.complete",
|
||||
message=message.model_dump_json(),
|
||||
exchange=OPERATION_COMPLETE_EXCHANGE,
|
||||
)
|
||||
logger.info(f"Published completion for operation {operation_id}")
|
||||
finally:
|
||||
await rabbitmq.disconnect()
|
||||
@@ -44,6 +44,20 @@ class ChatConfig(BaseSettings):
|
||||
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
|
||||
)
|
||||
|
||||
# Stream registry configuration for SSE reconnection
|
||||
stream_ttl: int = Field(
|
||||
default=3600,
|
||||
description="TTL in seconds for stream data in Redis (1 hour)",
|
||||
)
|
||||
stream_max_length: int = Field(
|
||||
default=1000,
|
||||
description="Maximum number of messages to store per stream",
|
||||
)
|
||||
internal_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
|
||||
)
|
||||
|
||||
# Langfuse Prompt Management Configuration
|
||||
# Note: Langfuse credentials are in Settings().secrets (settings.py)
|
||||
langfuse_prompt_name: str = Field(
|
||||
@@ -82,6 +96,14 @@ class ChatConfig(BaseSettings):
|
||||
v = "https://openrouter.ai/api/v1"
|
||||
return v
|
||||
|
||||
@field_validator("internal_api_key", mode="before")
|
||||
@classmethod
|
||||
def get_internal_api_key(cls, v):
|
||||
"""Get internal API key from environment if not provided."""
|
||||
if v is None:
|
||||
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
||||
return v
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
|
||||
@@ -4,16 +4,19 @@ import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
|
||||
import orjson
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, Query, Security
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .config import ChatConfig
|
||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||
from .response_model import StreamFinish, StreamHeartbeat
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
@@ -81,6 +84,14 @@ class ListSessionsResponse(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
class OperationCompleteRequest(BaseModel):
|
||||
"""Request model for external completion webhook."""
|
||||
|
||||
success: bool
|
||||
result: dict | str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@@ -366,6 +377,267 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Task Streaming (SSE Reconnection) ==========
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}/stream",
|
||||
)
|
||||
async def stream_task(
|
||||
task_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
last_message_id: str = Query(
|
||||
default="0-0",
|
||||
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Reconnect to a long-running task's SSE stream.
|
||||
|
||||
When a long-running operation (like agent generation) starts, the client
|
||||
receives a task_id. If the connection drops, the client can reconnect
|
||||
using this endpoint to resume receiving updates.
|
||||
|
||||
Args:
|
||||
task_id: The task ID from the operation_started response.
|
||||
user_id: Authenticated user ID for ownership validation.
|
||||
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
|
||||
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If task_id is not found or user doesn't have access.
|
||||
"""
|
||||
# Get subscriber queue from stream registry
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
user_id=user_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
raise NotFoundError(f"Task {task_id} not found or access denied.")
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
import asyncio
|
||||
|
||||
chunk_count = 0
|
||||
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# Wait for next chunk with timeout for heartbeats
|
||||
chunk = await asyncio.wait_for(
|
||||
subscriber_queue.get(), timeout=heartbeat_interval
|
||||
)
|
||||
chunk_count += 1
|
||||
yield chunk.to_sse()
|
||||
|
||||
# Check for finish signal
|
||||
if isinstance(chunk, StreamFinish):
|
||||
logger.info(
|
||||
f"Task stream completed for task {task_id}, "
|
||||
f"chunk_count={chunk_count}"
|
||||
)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
# Send heartbeat to keep connection alive
|
||||
yield StreamHeartbeat().to_sse()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
|
||||
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}",
|
||||
)
|
||||
async def get_task_status(
|
||||
task_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
) -> dict:
|
||||
"""
|
||||
Get the status of a long-running task.
|
||||
|
||||
Args:
|
||||
task_id: The task ID to check.
|
||||
user_id: Authenticated user ID for ownership validation.
|
||||
|
||||
Returns:
|
||||
dict: Task status including task_id, status, tool_name, and operation_id.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If task_id is not found or user doesn't have access.
|
||||
"""
|
||||
task = await stream_registry.get_task(task_id)
|
||||
|
||||
if task is None:
|
||||
raise NotFoundError(f"Task {task_id} not found.")
|
||||
|
||||
# Validate ownership
|
||||
if user_id and task.user_id and task.user_id != user_id:
|
||||
raise NotFoundError(f"Task {task_id} not found.")
|
||||
|
||||
return {
|
||||
"task_id": task.task_id,
|
||||
"session_id": task.session_id,
|
||||
"status": task.status,
|
||||
"tool_name": task.tool_name,
|
||||
"operation_id": task.operation_id,
|
||||
"created_at": task.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
# ========== External Completion Webhook ==========
|
||||
|
||||
|
||||
@router.post(
|
||||
"/operations/{operation_id}/complete",
|
||||
status_code=200,
|
||||
)
|
||||
async def complete_operation(
|
||||
operation_id: str,
|
||||
request: OperationCompleteRequest,
|
||||
x_api_key: str | None = Header(default=None),
|
||||
) -> dict:
|
||||
"""
|
||||
External completion webhook for long-running operations.
|
||||
|
||||
Called by Agent Generator (or other services) when an operation completes.
|
||||
This triggers the stream registry to publish completion and continue LLM generation.
|
||||
|
||||
Args:
|
||||
operation_id: The operation ID to complete.
|
||||
request: Completion payload with success status and result/error.
|
||||
x_api_key: Internal API key for authentication.
|
||||
|
||||
Returns:
|
||||
dict: Status of the completion.
|
||||
|
||||
Raises:
|
||||
HTTPException: If API key is invalid or operation not found.
|
||||
"""
|
||||
# Validate internal API key - reject if not configured or invalid
|
||||
if not config.internal_api_key:
|
||||
logger.error(
|
||||
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Webhook not available: internal API key not configured",
|
||||
)
|
||||
if x_api_key != config.internal_api_key:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
# Find task by operation_id
|
||||
task = await stream_registry.find_task_by_operation_id(operation_id)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Operation {operation_id} not found",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Received completion webhook for operation {operation_id} "
|
||||
f"(task_id={task.task_id}, success={request.success})"
|
||||
)
|
||||
|
||||
if request.success:
|
||||
# Publish result to stream registry
|
||||
from .response_model import StreamToolOutputAvailable
|
||||
|
||||
result_output = request.result if request.result else {"status": "completed"}
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=task.tool_call_id,
|
||||
toolName=task.tool_name,
|
||||
output=(
|
||||
result_output
|
||||
if isinstance(result_output, str)
|
||||
else orjson.dumps(result_output).decode("utf-8")
|
||||
),
|
||||
success=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Update pending operation in database
|
||||
from . import service as svc
|
||||
|
||||
result_str = (
|
||||
request.result
|
||||
if isinstance(request.result, str)
|
||||
else (
|
||||
orjson.dumps(request.result).decode("utf-8")
|
||||
if request.result
|
||||
else '{"status": "completed"}'
|
||||
)
|
||||
)
|
||||
await svc._update_pending_operation(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
result=result_str,
|
||||
)
|
||||
|
||||
# Generate LLM continuation with streaming
|
||||
await svc._generate_llm_continuation_with_streaming(
|
||||
session_id=task.session_id,
|
||||
user_id=task.user_id,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
|
||||
# Mark task as completed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="completed")
|
||||
await svc._mark_operation_completed(task.tool_call_id)
|
||||
else:
|
||||
# Publish error to stream registry
|
||||
from .response_model import StreamError
|
||||
|
||||
error_msg = request.error or "Operation failed"
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText=error_msg),
|
||||
)
|
||||
# Send finish event to end the stream
|
||||
await stream_registry.publish_chunk(task.task_id, StreamFinish())
|
||||
|
||||
# Update pending operation with error
|
||||
from . import service as svc
|
||||
from .tools.models import ErrorResponse
|
||||
|
||||
error_response = ErrorResponse(
|
||||
message=error_msg,
|
||||
error=request.error,
|
||||
)
|
||||
await svc._update_pending_operation(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
result=error_response.model_dump_json(),
|
||||
)
|
||||
|
||||
# Mark task as failed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
await svc._mark_operation_completed(task.tool_call_id)
|
||||
|
||||
return {"status": "ok", "task_id": task.task_id}
|
||||
|
||||
|
||||
# ========== Health Check ==========
|
||||
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from backend.util.exceptions import NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from . import db as chat_db
|
||||
from . import stream_registry
|
||||
from .config import ChatConfig
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
@@ -1610,8 +1611,9 @@ async def _yield_tool_call(
|
||||
)
|
||||
return
|
||||
|
||||
# Generate operation ID
|
||||
# Generate operation ID and task ID
|
||||
operation_id = str(uuid_module.uuid4())
|
||||
task_id = str(uuid_module.uuid4())
|
||||
|
||||
# Build a user-friendly message based on tool and arguments
|
||||
if tool_name == "create_agent":
|
||||
@@ -1654,6 +1656,16 @@ async def _yield_tool_call(
|
||||
|
||||
# Wrap session save and task creation in try-except to release lock on failure
|
||||
try:
|
||||
# Create task in stream registry for SSE reconnection support
|
||||
await stream_registry.create_task(
|
||||
task_id=task_id,
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
operation_id=operation_id,
|
||||
)
|
||||
|
||||
# Save assistant message with tool_call FIRST (required by LLM)
|
||||
assistant_message = ChatMessage(
|
||||
role="assistant",
|
||||
@@ -1675,23 +1687,27 @@ async def _yield_tool_call(
|
||||
session.messages.append(pending_message)
|
||||
await upsert_chat_session(session)
|
||||
logger.info(
|
||||
f"Saved pending operation {operation_id} for tool {tool_name} "
|
||||
f"in session {session.session_id}"
|
||||
f"Saved pending operation {operation_id} (task_id={task_id}) "
|
||||
f"for tool {tool_name} in session {session.session_id}"
|
||||
)
|
||||
|
||||
# Store task reference in module-level set to prevent GC before completion
|
||||
task = asyncio.create_task(
|
||||
_execute_long_running_tool(
|
||||
bg_task = asyncio.create_task(
|
||||
_execute_long_running_tool_with_streaming(
|
||||
tool_name=tool_name,
|
||||
parameters=arguments,
|
||||
tool_call_id=tool_call_id,
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
session_id=session.session_id,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
)
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
_background_tasks.add(bg_task)
|
||||
bg_task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
# Associate the asyncio task with the stream registry task
|
||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||
except Exception as e:
|
||||
# Roll back appended messages to prevent data corruption on subsequent saves
|
||||
if (
|
||||
@@ -1709,6 +1725,11 @@ async def _yield_tool_call(
|
||||
|
||||
# Release the Redis lock since the background task won't be spawned
|
||||
await _mark_operation_completed(tool_call_id)
|
||||
# Mark stream registry task as failed if it was created
|
||||
try:
|
||||
await stream_registry.mark_task_completed(task_id, status="failed")
|
||||
except Exception:
|
||||
pass
|
||||
logger.error(
|
||||
f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
|
||||
)
|
||||
@@ -1722,6 +1743,7 @@ async def _yield_tool_call(
|
||||
message=started_msg,
|
||||
operation_id=operation_id,
|
||||
tool_name=tool_name,
|
||||
task_id=task_id, # Include task_id for SSE reconnection
|
||||
).model_dump_json(),
|
||||
success=True,
|
||||
)
|
||||
@@ -1791,6 +1813,9 @@ async def _execute_long_running_tool(
|
||||
|
||||
This function runs independently of the SSE connection, so the operation
|
||||
survives if the user closes their browser tab.
|
||||
|
||||
NOTE: This is the legacy function without stream registry support.
|
||||
Use _execute_long_running_tool_with_streaming for new implementations.
|
||||
"""
|
||||
try:
|
||||
# Load fresh session (not stale reference)
|
||||
@@ -1838,6 +1863,128 @@ async def _execute_long_running_tool(
|
||||
await _mark_operation_completed(tool_call_id)
|
||||
|
||||
|
||||
async def _execute_long_running_tool_with_streaming(
|
||||
tool_name: str,
|
||||
parameters: dict[str, Any],
|
||||
tool_call_id: str,
|
||||
operation_id: str,
|
||||
task_id: str,
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
) -> None:
|
||||
"""Execute a long-running tool with stream registry support for SSE reconnection.
|
||||
|
||||
This function runs independently of the SSE connection, publishes progress
|
||||
to the stream registry, and survives if the user closes their browser tab.
|
||||
Clients can reconnect via GET /chat/tasks/{task_id}/stream to resume streaming.
|
||||
|
||||
If the external service returns a 202 Accepted (async), this function exits
|
||||
early and lets the RabbitMQ completion consumer handle the rest.
|
||||
"""
|
||||
# Track whether we delegated to async processing - if so, the RabbitMQ
|
||||
# completion consumer will handle cleanup, not us
|
||||
delegated_to_async = False
|
||||
|
||||
try:
|
||||
# Load fresh session (not stale reference)
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
logger.error(f"Session {session_id} not found for background tool")
|
||||
await stream_registry.mark_task_completed(task_id, status="failed")
|
||||
return
|
||||
|
||||
# Pass operation_id and task_id to the tool for async processing
|
||||
enriched_parameters = {
|
||||
**parameters,
|
||||
"_operation_id": operation_id,
|
||||
"_task_id": task_id,
|
||||
}
|
||||
|
||||
# Execute the actual tool
|
||||
result = await execute_tool(
|
||||
tool_name=tool_name,
|
||||
parameters=enriched_parameters,
|
||||
tool_call_id=tool_call_id,
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Check if the tool result indicates async processing
|
||||
# (e.g., Agent Generator returned 202 Accepted)
|
||||
try:
|
||||
result_data = orjson.loads(result.output) if result.output else {}
|
||||
if result_data.get("status") == "accepted":
|
||||
logger.info(
|
||||
f"Tool {tool_name} delegated to async processing "
|
||||
f"(operation_id={operation_id}, task_id={task_id}). "
|
||||
f"RabbitMQ completion consumer will handle the rest."
|
||||
)
|
||||
# Don't publish result, don't continue with LLM, and don't cleanup
|
||||
# The RabbitMQ consumer will handle everything when the external
|
||||
# service completes and publishes to the queue
|
||||
delegated_to_async = True
|
||||
return
|
||||
except (orjson.JSONDecodeError, TypeError):
|
||||
pass # Not JSON or not async - continue normally
|
||||
|
||||
# Publish tool result to stream registry
|
||||
await stream_registry.publish_chunk(task_id, result)
|
||||
|
||||
# Update the pending message with result
|
||||
result_str = (
|
||||
result.output
|
||||
if isinstance(result.output, str)
|
||||
else orjson.dumps(result.output).decode("utf-8")
|
||||
)
|
||||
await _update_pending_operation(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
result=result_str,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Background tool {tool_name} completed for session {session_id} "
|
||||
f"(task_id={task_id})"
|
||||
)
|
||||
|
||||
# Generate LLM continuation and stream chunks to registry
|
||||
await _generate_llm_continuation_with_streaming(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
# Mark task as completed in stream registry
|
||||
await stream_registry.mark_task_completed(task_id, status="completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True)
|
||||
error_response = ErrorResponse(
|
||||
message=f"Tool {tool_name} failed: {str(e)}",
|
||||
)
|
||||
|
||||
# Publish error to stream registry followed by finish event
|
||||
await stream_registry.publish_chunk(
|
||||
task_id,
|
||||
StreamError(errorText=str(e)),
|
||||
)
|
||||
await stream_registry.publish_chunk(task_id, StreamFinish())
|
||||
|
||||
await _update_pending_operation(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
result=error_response.model_dump_json(),
|
||||
)
|
||||
|
||||
# Mark task as failed in stream registry
|
||||
await stream_registry.mark_task_completed(task_id, status="failed")
|
||||
finally:
|
||||
# Only cleanup if we didn't delegate to async processing
|
||||
# For async path, the RabbitMQ completion consumer handles cleanup
|
||||
if not delegated_to_async:
|
||||
await _mark_operation_completed(tool_call_id)
|
||||
|
||||
|
||||
async def _update_pending_operation(
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
@@ -1964,3 +2111,128 @@ async def _generate_llm_continuation(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
|
||||
|
||||
|
||||
async def _generate_llm_continuation_with_streaming(
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
task_id: str,
|
||||
) -> None:
|
||||
"""Generate an LLM response with streaming to the stream registry.
|
||||
|
||||
This is called by background tasks to continue the conversation
|
||||
after a tool result is saved. Chunks are published to the stream registry
|
||||
so reconnecting clients can receive them.
|
||||
"""
|
||||
import uuid as uuid_module
|
||||
|
||||
try:
|
||||
# Load fresh session from DB (bypass cache to get the updated tool result)
|
||||
await invalidate_session_cache(session_id)
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
if not session:
|
||||
logger.error(f"Session {session_id} not found for LLM continuation")
|
||||
return
|
||||
|
||||
# Build system prompt
|
||||
system_prompt, _ = await _build_system_prompt(user_id)
|
||||
|
||||
# Build messages in OpenAI format
|
||||
messages = session.to_openai_messages()
|
||||
if system_prompt:
|
||||
from openai.types.chat import ChatCompletionSystemMessageParam
|
||||
|
||||
system_message = ChatCompletionSystemMessageParam(
|
||||
role="system",
|
||||
content=system_prompt,
|
||||
)
|
||||
messages = [system_message] + messages
|
||||
|
||||
# Build extra_body for tracing
|
||||
extra_body: dict[str, Any] = {
|
||||
"posthogProperties": {
|
||||
"environment": settings.config.app_env.value,
|
||||
},
|
||||
}
|
||||
if user_id:
|
||||
extra_body["user"] = user_id[:128]
|
||||
extra_body["posthogDistinctId"] = user_id
|
||||
if session_id:
|
||||
extra_body["session_id"] = session_id[:128]
|
||||
|
||||
# Make streaming LLM call (no tools - just text response)
|
||||
from typing import cast
|
||||
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
# Generate unique IDs for AI SDK protocol
|
||||
message_id = str(uuid_module.uuid4())
|
||||
text_block_id = str(uuid_module.uuid4())
|
||||
|
||||
# Publish start event
|
||||
await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id))
|
||||
await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id))
|
||||
|
||||
# Stream the response
|
||||
stream = await client.chat.completions.create(
|
||||
model=config.model,
|
||||
messages=cast(list[ChatCompletionMessageParam], messages),
|
||||
extra_body=extra_body,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
assistant_content = ""
|
||||
async for chunk in stream:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
delta = chunk.choices[0].delta.content
|
||||
assistant_content += delta
|
||||
# Publish delta to stream registry
|
||||
await stream_registry.publish_chunk(
|
||||
task_id,
|
||||
StreamTextDelta(id=text_block_id, delta=delta),
|
||||
)
|
||||
|
||||
# Publish end events
|
||||
await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id))
|
||||
|
||||
if assistant_content:
|
||||
# Reload session from DB to avoid race condition with user messages
|
||||
fresh_session = await get_chat_session(session_id, user_id)
|
||||
if not fresh_session:
|
||||
logger.error(
|
||||
f"Session {session_id} disappeared during LLM continuation"
|
||||
)
|
||||
return
|
||||
|
||||
# Save assistant message to database
|
||||
assistant_message = ChatMessage(
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
)
|
||||
fresh_session.messages.append(assistant_message)
|
||||
|
||||
# Save to database (not cache) to persist the response
|
||||
await upsert_chat_session(fresh_session)
|
||||
|
||||
# Invalidate cache so next poll/refresh gets fresh data
|
||||
await invalidate_session_cache(session_id)
|
||||
|
||||
logger.info(
|
||||
f"Generated streaming LLM continuation for session {session_id} "
|
||||
f"(task_id={task_id}), response length: {len(assistant_content)}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Streaming LLM continuation returned empty response for {session_id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to generate streaming LLM continuation: {e}", exc_info=True
|
||||
)
|
||||
# Publish error to stream registry followed by finish event
|
||||
await stream_registry.publish_chunk(
|
||||
task_id,
|
||||
StreamError(errorText=f"Failed to generate response: {e}"),
|
||||
)
|
||||
await stream_registry.publish_chunk(task_id, StreamFinish())
|
||||
|
||||
@@ -0,0 +1,648 @@
|
||||
"""Stream registry for managing reconnectable SSE streams.
|
||||
|
||||
This module provides a registry for tracking active streaming tasks and their
|
||||
messages. It supports:
|
||||
- Creating tasks with unique IDs for long-running operations
|
||||
- Publishing stream messages to both Redis Streams and in-memory queues
|
||||
- Subscribing to tasks with replay of missed messages
|
||||
- Looking up tasks by operation_id for webhook callbacks
|
||||
- Cross-pod real-time delivery via Redis pub/sub
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
|
||||
import orjson
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
from .config import ChatConfig
|
||||
from .response_model import StreamBaseResponse, StreamFinish
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
# Track active pub/sub listeners for cross-pod delivery
|
||||
_pubsub_listeners: dict[str, asyncio.Task] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveTask:
|
||||
"""Represents an active streaming task."""
|
||||
|
||||
task_id: str
|
||||
session_id: str
|
||||
user_id: str | None
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
operation_id: str
|
||||
status: Literal["running", "completed", "failed"] = "running"
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
queue: asyncio.Queue[StreamBaseResponse] = field(default_factory=asyncio.Queue)
|
||||
asyncio_task: asyncio.Task | None = None
|
||||
# Lock for atomic status checks and subscriber management
|
||||
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
# Set of subscriber queues for fan-out
|
||||
subscribers: set[asyncio.Queue[StreamBaseResponse]] = field(default_factory=set)
|
||||
|
||||
|
||||
# Module-level registry for active tasks
|
||||
_active_tasks: dict[str, ActiveTask] = {}
|
||||
|
||||
# Redis key patterns
|
||||
TASK_META_PREFIX = "chat:task:meta:" # Hash for task metadata
|
||||
TASK_STREAM_PREFIX = "chat:stream:" # Redis Stream for messages
|
||||
TASK_OP_PREFIX = "chat:task:op:" # Operation ID -> task_id mapping
|
||||
TASK_PUBSUB_PREFIX = "chat:task:pubsub:" # Pub/sub channel for cross-pod delivery
|
||||
|
||||
|
||||
def _get_task_meta_key(task_id: str) -> str:
|
||||
"""Get Redis key for task metadata."""
|
||||
return f"{TASK_META_PREFIX}{task_id}"
|
||||
|
||||
|
||||
def _get_task_stream_key(task_id: str) -> str:
|
||||
"""Get Redis key for task message stream."""
|
||||
return f"{TASK_STREAM_PREFIX}{task_id}"
|
||||
|
||||
|
||||
def _get_operation_mapping_key(operation_id: str) -> str:
|
||||
"""Get Redis key for operation_id to task_id mapping."""
|
||||
return f"{TASK_OP_PREFIX}{operation_id}"
|
||||
|
||||
|
||||
def _get_task_pubsub_channel(task_id: str) -> str:
|
||||
"""Get Redis pub/sub channel for task cross-pod delivery."""
|
||||
return f"{TASK_PUBSUB_PREFIX}{task_id}"
|
||||
|
||||
|
||||
async def create_task(
|
||||
task_id: str,
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
operation_id: str,
|
||||
) -> ActiveTask:
|
||||
"""Create a new streaming task in memory and Redis.
|
||||
|
||||
Args:
|
||||
task_id: Unique identifier for the task
|
||||
session_id: Chat session ID
|
||||
user_id: User ID (may be None for anonymous)
|
||||
tool_call_id: Tool call ID from the LLM
|
||||
tool_name: Name of the tool being executed
|
||||
operation_id: Operation ID for webhook callbacks
|
||||
|
||||
Returns:
|
||||
The created ActiveTask instance
|
||||
"""
|
||||
task = ActiveTask(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
operation_id=operation_id,
|
||||
)
|
||||
|
||||
# Store in memory registry
|
||||
_active_tasks[task_id] = task
|
||||
|
||||
# Store metadata in Redis for durability
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
op_key = _get_operation_mapping_key(operation_id)
|
||||
|
||||
await redis.hset( # type: ignore[misc]
|
||||
meta_key,
|
||||
mapping={
|
||||
"task_id": task_id,
|
||||
"session_id": session_id,
|
||||
"user_id": user_id or "",
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"operation_id": operation_id,
|
||||
"status": task.status,
|
||||
"created_at": task.created_at.isoformat(),
|
||||
},
|
||||
)
|
||||
await redis.expire(meta_key, config.stream_ttl)
|
||||
|
||||
# Create operation_id -> task_id mapping for webhook lookups
|
||||
await redis.set(op_key, task_id, ex=config.stream_ttl)
|
||||
|
||||
logger.info(
|
||||
f"Created streaming task {task_id} for operation {operation_id} "
|
||||
f"in session {session_id}"
|
||||
)
|
||||
|
||||
return task
|
||||
|
||||
|
||||
async def publish_chunk(
|
||||
task_id: str,
|
||||
chunk: StreamBaseResponse,
|
||||
) -> str:
|
||||
"""Publish a chunk to the task's stream.
|
||||
|
||||
Delivers to in-memory subscribers first (for real-time), then persists to
|
||||
Redis Stream (for replay). This order ensures live subscribers get messages
|
||||
even if Redis temporarily fails.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to publish to
|
||||
chunk: The stream response chunk to publish
|
||||
|
||||
Returns:
|
||||
The Redis Stream message ID (format: "timestamp-sequence"), or "0-0" if
|
||||
Redis persistence failed
|
||||
"""
|
||||
# Deliver to in-memory subscribers FIRST for real-time updates
|
||||
task = _active_tasks.get(task_id)
|
||||
if task:
|
||||
async with task.lock:
|
||||
for subscriber_queue in task.subscribers:
|
||||
try:
|
||||
subscriber_queue.put_nowait(chunk)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(
|
||||
f"Subscriber queue full for task {task_id}, dropping chunk"
|
||||
)
|
||||
|
||||
# Then persist to Redis Stream for replay (with error handling)
|
||||
message_id = "0-0"
|
||||
chunk_json = chunk.model_dump_json()
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
|
||||
# Add to Redis Stream with auto-generated ID
|
||||
# The ID format is "timestamp-sequence" which gives us ordering
|
||||
raw_id = await redis.xadd(
|
||||
stream_key,
|
||||
{"data": chunk_json},
|
||||
maxlen=config.stream_max_length,
|
||||
)
|
||||
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
|
||||
|
||||
# Publish to pub/sub for cross-pod real-time delivery
|
||||
pubsub_channel = _get_task_pubsub_channel(task_id)
|
||||
await redis.publish(pubsub_channel, chunk_json)
|
||||
|
||||
logger.debug(f"Published chunk to task {task_id}, message_id={message_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to persist chunk to Redis for task {task_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return message_id
|
||||
|
||||
|
||||
async def subscribe_to_task(
|
||||
task_id: str,
|
||||
user_id: str | None,
|
||||
last_message_id: str = "0-0",
|
||||
) -> asyncio.Queue[StreamBaseResponse] | None:
|
||||
"""Subscribe to a task's stream with replay of missed messages.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to subscribe to
|
||||
user_id: User ID for ownership validation
|
||||
last_message_id: Last Redis Stream message ID received ("0-0" for full replay)
|
||||
|
||||
Returns:
|
||||
An asyncio Queue that will receive stream chunks, or None if task not found
|
||||
or user doesn't have access
|
||||
"""
|
||||
# Check in-memory first
|
||||
task = _active_tasks.get(task_id)
|
||||
|
||||
if task:
|
||||
# Validate ownership
|
||||
if user_id and task.user_id and task.user_id != user_id:
|
||||
logger.warning(
|
||||
f"User {user_id} attempted to subscribe to task {task_id} "
|
||||
f"owned by {task.user_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Create a new queue for this subscriber
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
|
||||
|
||||
# Replay from Redis Stream
|
||||
redis = await get_redis_async()
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
|
||||
# Track the last message ID we've seen for gap detection
|
||||
replay_last_id = last_message_id
|
||||
|
||||
# Read all messages from stream starting after last_message_id
|
||||
# xread returns messages with ID > last_message_id
|
||||
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
||||
|
||||
if messages:
|
||||
# messages format: [[stream_name, [(id, {data: json}), ...]]]
|
||||
for _stream_name, stream_messages in messages:
|
||||
for msg_id, msg_data in stream_messages:
|
||||
# Track the last message ID we've processed
|
||||
replay_last_id = (
|
||||
msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||
)
|
||||
if b"data" in msg_data:
|
||||
try:
|
||||
chunk_data = orjson.loads(msg_data[b"data"])
|
||||
# Reconstruct the appropriate response type
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
await subscriber_queue.put(chunk)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to replay message: {e}")
|
||||
|
||||
# Atomically check status and register subscriber under lock
|
||||
# This prevents race condition where task completes between check and subscribe
|
||||
should_start_pubsub = False
|
||||
async with task.lock:
|
||||
if task.status == "running":
|
||||
# Register this subscriber for live updates
|
||||
task.subscribers.add(subscriber_queue)
|
||||
# Start pub/sub listener if this is the first subscriber
|
||||
should_start_pubsub = len(task.subscribers) == 1
|
||||
logger.debug(
|
||||
f"Registered subscriber for task {task_id}, "
|
||||
f"total subscribers: {len(task.subscribers)}"
|
||||
)
|
||||
else:
|
||||
# Task is done, add finish marker
|
||||
await subscriber_queue.put(StreamFinish())
|
||||
|
||||
# After registering, do a second read to catch any messages published
|
||||
# between the first read and registration (closes the race window)
|
||||
if task.status == "running":
|
||||
gap_messages = await redis.xread(
|
||||
{stream_key: replay_last_id}, block=0, count=1000
|
||||
)
|
||||
if gap_messages:
|
||||
for _stream_name, stream_messages in gap_messages:
|
||||
for _msg_id, msg_data in stream_messages:
|
||||
if b"data" in msg_data:
|
||||
try:
|
||||
chunk_data = orjson.loads(msg_data[b"data"])
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
await subscriber_queue.put(chunk)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to replay gap message: {e}")
|
||||
|
||||
# Start pub/sub listener outside the lock to avoid deadlocks
|
||||
if should_start_pubsub:
|
||||
await start_pubsub_listener(task_id)
|
||||
|
||||
return subscriber_queue
|
||||
|
||||
# Try to load from Redis if not in memory
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
|
||||
if not meta:
|
||||
logger.warning(f"Task {task_id} not found in memory or Redis")
|
||||
return None
|
||||
|
||||
# Validate ownership
|
||||
task_user_id = meta.get(b"user_id", b"").decode() or None
|
||||
if user_id and task_user_id and task_user_id != user_id:
|
||||
logger.warning(
|
||||
f"User {user_id} attempted to subscribe to task {task_id} "
|
||||
f"owned by {task_user_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Replay from Redis Stream only (task is not in memory, so it's completed/crashed)
|
||||
subscriber_queue = asyncio.Queue()
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
|
||||
# Read all messages starting after last_message_id
|
||||
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
||||
|
||||
if messages:
|
||||
for _stream_name, stream_messages in messages:
|
||||
for _msg_id, msg_data in stream_messages:
|
||||
if b"data" in msg_data:
|
||||
try:
|
||||
chunk_data = orjson.loads(msg_data[b"data"])
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
await subscriber_queue.put(chunk)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to replay message: {e}")
|
||||
|
||||
# Add finish marker since task is not active
|
||||
await subscriber_queue.put(StreamFinish())
|
||||
|
||||
return subscriber_queue
|
||||
|
||||
|
||||
async def mark_task_completed(
|
||||
task_id: str,
|
||||
status: Literal["completed", "failed"] = "completed",
|
||||
) -> None:
|
||||
"""Mark a task as completed and publish final event.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to mark as completed
|
||||
status: Final status ("completed" or "failed")
|
||||
"""
|
||||
task = _active_tasks.get(task_id)
|
||||
|
||||
if task:
|
||||
# Acquire lock to prevent new subscribers during completion
|
||||
async with task.lock:
|
||||
task.status = status
|
||||
# Send finish event directly to all current subscribers
|
||||
finish_event = StreamFinish()
|
||||
for subscriber_queue in task.subscribers:
|
||||
try:
|
||||
subscriber_queue.put_nowait(finish_event)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(
|
||||
f"Subscriber queue full for task {task_id} during completion"
|
||||
)
|
||||
# Clear subscribers since task is done
|
||||
task.subscribers.clear()
|
||||
|
||||
# Stop pub/sub listener since task is done
|
||||
await stop_pubsub_listener(task_id)
|
||||
|
||||
# Also publish to Redis Stream for replay (and pub/sub for cross-pod)
|
||||
await publish_chunk(task_id, StreamFinish())
|
||||
|
||||
# Remove from active tasks after a short delay to allow subscribers to finish
|
||||
async def _cleanup():
|
||||
await asyncio.sleep(5)
|
||||
_active_tasks.pop(task_id, None)
|
||||
logger.info(f"Cleaned up task {task_id} from memory")
|
||||
|
||||
asyncio.create_task(_cleanup())
|
||||
|
||||
# Update Redis metadata
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
await redis.hset(meta_key, "status", status) # type: ignore[misc]
|
||||
|
||||
logger.info(f"Marked task {task_id} as {status}")
|
||||
|
||||
|
||||
async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
|
||||
"""Find a task by its operation ID.
|
||||
|
||||
Used by webhook callbacks to locate the task to update.
|
||||
|
||||
Args:
|
||||
operation_id: Operation ID to search for
|
||||
|
||||
Returns:
|
||||
ActiveTask if found, None otherwise
|
||||
"""
|
||||
# Check in-memory first
|
||||
for task in _active_tasks.values():
|
||||
if task.operation_id == operation_id:
|
||||
return task
|
||||
|
||||
# Try Redis lookup
|
||||
redis = await get_redis_async()
|
||||
op_key = _get_operation_mapping_key(operation_id)
|
||||
task_id = await redis.get(op_key)
|
||||
|
||||
if task_id:
|
||||
task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
|
||||
# Check if task is in memory
|
||||
if task_id_str in _active_tasks:
|
||||
return _active_tasks[task_id_str]
|
||||
|
||||
# Load metadata from Redis
|
||||
meta_key = _get_task_meta_key(task_id_str)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
|
||||
if meta:
|
||||
# Reconstruct task object (not fully active, but has metadata)
|
||||
return ActiveTask(
|
||||
task_id=meta.get(b"task_id", b"").decode(),
|
||||
session_id=meta.get(b"session_id", b"").decode(),
|
||||
user_id=meta.get(b"user_id", b"").decode() or None,
|
||||
tool_call_id=meta.get(b"tool_call_id", b"").decode(),
|
||||
tool_name=meta.get(b"tool_name", b"").decode(),
|
||||
operation_id=operation_id,
|
||||
status=meta.get(b"status", b"running").decode(), # type: ignore
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def get_task(task_id: str) -> ActiveTask | None:
|
||||
"""Get a task by its ID.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to look up
|
||||
|
||||
Returns:
|
||||
ActiveTask if found, None otherwise
|
||||
"""
|
||||
# Check in-memory first
|
||||
if task_id in _active_tasks:
|
||||
return _active_tasks[task_id]
|
||||
|
||||
# Try Redis lookup
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
|
||||
if meta:
|
||||
return ActiveTask(
|
||||
task_id=meta.get(b"task_id", b"").decode(),
|
||||
session_id=meta.get(b"session_id", b"").decode(),
|
||||
user_id=meta.get(b"user_id", b"").decode() or None,
|
||||
tool_call_id=meta.get(b"tool_call_id", b"").decode(),
|
||||
tool_name=meta.get(b"tool_name", b"").decode(),
|
||||
operation_id=meta.get(b"operation_id", b"").decode(),
|
||||
status=meta.get(b"status", b"running").decode(), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
||||
"""Reconstruct a StreamBaseResponse from JSON data.
|
||||
|
||||
Args:
|
||||
chunk_data: Parsed JSON data from Redis
|
||||
|
||||
Returns:
|
||||
Reconstructed response object, or None if unknown type
|
||||
"""
|
||||
from .response_model import (
|
||||
ResponseType,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamHeartbeat,
|
||||
StreamStart,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
|
||||
chunk_type = chunk_data.get("type")
|
||||
|
||||
try:
|
||||
if chunk_type == ResponseType.START.value:
|
||||
return StreamStart(**chunk_data)
|
||||
elif chunk_type == ResponseType.FINISH.value:
|
||||
return StreamFinish(**chunk_data)
|
||||
elif chunk_type == ResponseType.TEXT_START.value:
|
||||
return StreamTextStart(**chunk_data)
|
||||
elif chunk_type == ResponseType.TEXT_DELTA.value:
|
||||
return StreamTextDelta(**chunk_data)
|
||||
elif chunk_type == ResponseType.TEXT_END.value:
|
||||
return StreamTextEnd(**chunk_data)
|
||||
elif chunk_type == ResponseType.TOOL_INPUT_START.value:
|
||||
return StreamToolInputStart(**chunk_data)
|
||||
elif chunk_type == ResponseType.TOOL_INPUT_AVAILABLE.value:
|
||||
return StreamToolInputAvailable(**chunk_data)
|
||||
elif chunk_type == ResponseType.TOOL_OUTPUT_AVAILABLE.value:
|
||||
return StreamToolOutputAvailable(**chunk_data)
|
||||
elif chunk_type == ResponseType.ERROR.value:
|
||||
return StreamError(**chunk_data)
|
||||
elif chunk_type == ResponseType.USAGE.value:
|
||||
return StreamUsage(**chunk_data)
|
||||
elif chunk_type == ResponseType.HEARTBEAT.value:
|
||||
return StreamHeartbeat(**chunk_data)
|
||||
else:
|
||||
logger.warning(f"Unknown chunk type: {chunk_type}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
|
||||
"""Associate an asyncio.Task with an ActiveTask.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
asyncio_task: The asyncio Task to associate
|
||||
"""
|
||||
task = _active_tasks.get(task_id)
|
||||
if task:
|
||||
task.asyncio_task = asyncio_task
|
||||
|
||||
|
||||
async def unsubscribe_from_task(
|
||||
task_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
) -> None:
|
||||
"""Unsubscribe a queue from a task's stream.
|
||||
|
||||
Should be called when a client disconnects to clean up resources.
|
||||
Also stops the pub/sub listener if there are no more local subscribers.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to unsubscribe from
|
||||
subscriber_queue: The queue to remove from subscribers
|
||||
"""
|
||||
task = _active_tasks.get(task_id)
|
||||
if task:
|
||||
async with task.lock:
|
||||
task.subscribers.discard(subscriber_queue)
|
||||
remaining = len(task.subscribers)
|
||||
logger.debug(
|
||||
f"Unsubscribed from task {task_id}, "
|
||||
f"remaining subscribers: {remaining}"
|
||||
)
|
||||
# Stop pub/sub listener if no more local subscribers
|
||||
if remaining == 0:
|
||||
await stop_pubsub_listener(task_id)
|
||||
|
||||
|
||||
async def start_pubsub_listener(task_id: str) -> None:
|
||||
"""Start listening to Redis pub/sub for cross-pod delivery.
|
||||
|
||||
This enables real-time updates when another pod publishes chunks for a task
|
||||
that has local subscribers on this pod.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to listen for
|
||||
"""
|
||||
if task_id in _pubsub_listeners:
|
||||
return # Already listening
|
||||
|
||||
task = _active_tasks.get(task_id)
|
||||
if not task:
|
||||
return
|
||||
|
||||
async def _listener():
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
pubsub = redis.pubsub()
|
||||
channel = _get_task_pubsub_channel(task_id)
|
||||
await pubsub.subscribe(channel)
|
||||
logger.debug(f"Started pub/sub listener for task {task_id}")
|
||||
|
||||
async for message in pubsub.listen():
|
||||
if message["type"] != "message":
|
||||
continue
|
||||
|
||||
try:
|
||||
chunk_data = orjson.loads(message["data"])
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
# Deliver to local subscribers
|
||||
local_task = _active_tasks.get(task_id)
|
||||
if local_task:
|
||||
async with local_task.lock:
|
||||
for queue in local_task.subscribers:
|
||||
try:
|
||||
queue.put_nowait(chunk)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
# Stop listening if this was a finish event
|
||||
if isinstance(chunk, StreamFinish):
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing pub/sub message: {e}")
|
||||
|
||||
await pubsub.unsubscribe(channel)
|
||||
await pubsub.close()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Pub/sub listener error for task {task_id}: {e}")
|
||||
finally:
|
||||
_pubsub_listeners.pop(task_id, None)
|
||||
logger.debug(f"Stopped pub/sub listener for task {task_id}")
|
||||
|
||||
listener_task = asyncio.create_task(_listener())
|
||||
_pubsub_listeners[task_id] = listener_task
|
||||
|
||||
|
||||
async def stop_pubsub_listener(task_id: str) -> None:
|
||||
"""Stop the pub/sub listener for a task.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to stop listening for
|
||||
"""
|
||||
listener = _pubsub_listeners.pop(task_id, None)
|
||||
if listener and not listener.done():
|
||||
listener.cancel()
|
||||
try:
|
||||
await listener
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.debug(f"Cancelled pub/sub listener for task {task_id}")
|
||||
@@ -57,21 +57,32 @@ async def decompose_goal(description: str, context: str = "") -> dict[str, Any]
|
||||
return await decompose_goal_external(description, context)
|
||||
|
||||
|
||||
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
||||
async def generate_agent(
|
||||
instructions: dict[str, Any],
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Generate agent JSON from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
operation_id: Operation ID for async processing (enables RabbitMQ callback)
|
||||
task_id: Task ID for async processing (enables RabbitMQ callback)
|
||||
|
||||
Returns:
|
||||
Agent JSON dict, error dict {"type": "error", ...}, or None on error
|
||||
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent")
|
||||
result = await generate_agent_external(instructions)
|
||||
result = await generate_agent_external(instructions, operation_id, task_id)
|
||||
|
||||
# Don't modify async response
|
||||
if result and result.get("status") == "accepted":
|
||||
return result
|
||||
|
||||
if result:
|
||||
# Check if it's an error response - pass through as-is
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
@@ -256,7 +267,10 @@ async def get_agent_as_json(
|
||||
|
||||
|
||||
async def generate_agent_patch(
|
||||
update_request: str, current_agent: dict[str, Any]
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Update an existing agent using natural language.
|
||||
|
||||
@@ -268,14 +282,18 @@ async def generate_agent_patch(
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
operation_id: Operation ID for async processing (enables RabbitMQ callback)
|
||||
task_id: Task ID for async processing (enables RabbitMQ callback)
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||
error dict {"type": "error", ...}, or None on unexpected error
|
||||
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
||||
return await generate_agent_patch_external(update_request, current_agent)
|
||||
return await generate_agent_patch_external(
|
||||
update_request, current_agent, operation_id, task_id
|
||||
)
|
||||
|
||||
@@ -207,21 +207,42 @@ async def decompose_goal_external(
|
||||
|
||||
async def generate_agent_external(
|
||||
instructions: dict[str, Any],
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate an agent from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
operation_id: Operation ID for async processing (enables RabbitMQ callback)
|
||||
task_id: Task ID for async processing (enables RabbitMQ callback)
|
||||
|
||||
Returns:
|
||||
Agent JSON dict on success, or error dict {"type": "error", ...} on error
|
||||
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
payload: dict[str, Any] = {"instructions": instructions}
|
||||
if operation_id and task_id:
|
||||
payload["operation_id"] = operation_id
|
||||
payload["task_id"] = task_id
|
||||
|
||||
try:
|
||||
response = await client.post(
|
||||
"/api/generate-agent", json={"instructions": instructions}
|
||||
)
|
||||
response = await client.post("/api/generate-agent", json=payload)
|
||||
|
||||
# Handle 202 Accepted for async processing
|
||||
if response.status_code == 202:
|
||||
logger.info(
|
||||
f"Agent Generator accepted async request "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return {
|
||||
"status": "accepted",
|
||||
"operation_id": operation_id,
|
||||
"task_id": task_id,
|
||||
}
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
@@ -251,27 +272,48 @@ async def generate_agent_external(
|
||||
|
||||
|
||||
async def generate_agent_patch_external(
|
||||
update_request: str, current_agent: dict[str, Any]
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate a patch for an existing agent.
|
||||
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
operation_id: Operation ID for async processing (enables RabbitMQ callback)
|
||||
task_id: Task ID for async processing (enables RabbitMQ callback)
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict, or error dict on error
|
||||
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
payload: dict[str, Any] = {
|
||||
"update_request": update_request,
|
||||
"current_agent_json": current_agent,
|
||||
}
|
||||
if operation_id and task_id:
|
||||
payload["operation_id"] = operation_id
|
||||
payload["task_id"] = task_id
|
||||
|
||||
try:
|
||||
response = await client.post(
|
||||
"/api/update-agent",
|
||||
json={
|
||||
"update_request": update_request,
|
||||
"current_agent_json": current_agent,
|
||||
},
|
||||
)
|
||||
response = await client.post("/api/update-agent", json=payload)
|
||||
|
||||
# Handle 202 Accepted for async processing
|
||||
if response.status_code == 202:
|
||||
logger.info(
|
||||
f"Agent Generator accepted async update request "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return {
|
||||
"status": "accepted",
|
||||
"operation_id": operation_id,
|
||||
"task_id": task_id,
|
||||
}
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
AsyncProcessingResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
@@ -96,6 +97,10 @@ class CreateAgentTool(BaseTool):
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
# Extract async processing params (passed by long-running tool handler)
|
||||
operation_id = kwargs.get("_operation_id")
|
||||
task_id = kwargs.get("_task_id")
|
||||
|
||||
if not description:
|
||||
return ErrorResponse(
|
||||
message="Please provide a description of what the agent should do.",
|
||||
@@ -192,7 +197,11 @@ class CreateAgentTool(BaseTool):
|
||||
|
||||
# Step 2: Generate agent JSON (external service handles fixing and validation)
|
||||
try:
|
||||
agent_json = await generate_agent(decomposition_result)
|
||||
agent_json = await generate_agent(
|
||||
decomposition_result,
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
@@ -232,6 +241,19 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if Agent Generator accepted for async processing
|
||||
if agent_json.get("status") == "accepted":
|
||||
logger.info(
|
||||
f"Agent generation delegated to async processing "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return AsyncProcessingResponse(
|
||||
message="Agent generation started. You'll be notified when it's complete.",
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agent_name = agent_json.get("name", "Generated Agent")
|
||||
agent_description = agent_json.get("description", "")
|
||||
node_count = len(agent_json.get("nodes", []))
|
||||
|
||||
@@ -16,6 +16,7 @@ from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
AsyncProcessingResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
@@ -103,6 +104,10 @@ class EditAgentTool(BaseTool):
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
# Extract async processing params (passed by long-running tool handler)
|
||||
operation_id = kwargs.get("_operation_id")
|
||||
task_id = kwargs.get("_task_id")
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the agent ID to edit.",
|
||||
@@ -134,7 +139,12 @@ class EditAgentTool(BaseTool):
|
||||
|
||||
# Step 2: Generate updated agent (external service handles fixing and validation)
|
||||
try:
|
||||
result = await generate_agent_patch(update_request, current_agent)
|
||||
result = await generate_agent_patch(
|
||||
update_request,
|
||||
current_agent,
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
@@ -153,6 +163,19 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if Agent Generator accepted for async processing
|
||||
if result.get("status") == "accepted":
|
||||
logger.info(
|
||||
f"Agent edit delegated to async processing "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return AsyncProcessingResponse(
|
||||
message="Agent edit started. You'll be notified when it's complete.",
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if the result is an error from the external service
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
|
||||
@@ -352,11 +352,15 @@ class OperationStartedResponse(ToolResponseBase):
|
||||
|
||||
This is returned immediately to the client while the operation continues
|
||||
to execute. The user can close the tab and check back later.
|
||||
|
||||
The task_id can be used to reconnect to the SSE stream via
|
||||
GET /chat/tasks/{task_id}/stream?last_idx=0
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||
operation_id: str
|
||||
tool_name: str
|
||||
task_id: str | None = None # For SSE reconnection
|
||||
|
||||
|
||||
class OperationPendingResponse(ToolResponseBase):
|
||||
@@ -380,3 +384,20 @@ class OperationInProgressResponse(ToolResponseBase):
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class AsyncProcessingResponse(ToolResponseBase):
|
||||
"""Response when an operation has been delegated to async processing.
|
||||
|
||||
This is returned by tools when the external service accepts the request
|
||||
for async processing (HTTP 202 Accepted). The RabbitMQ completion consumer
|
||||
will handle the result when the external service completes.
|
||||
|
||||
The status field is specifically "accepted" to allow the long-running tool
|
||||
handler to detect this response and skip LLM continuation.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||
status: str = "accepted" # Must be "accepted" for detection
|
||||
operation_id: str | None = None
|
||||
task_id: str | None = None
|
||||
|
||||
@@ -40,6 +40,10 @@ import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.api.features.chat.completion_consumer import (
|
||||
start_completion_consumer,
|
||||
stop_completion_consumer,
|
||||
)
|
||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
@@ -118,9 +122,21 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
# Start chat completion consumer for RabbitMQ notifications
|
||||
try:
|
||||
await start_completion_consumer()
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not start chat completion consumer: {e}")
|
||||
|
||||
with launch_darkly_context():
|
||||
yield
|
||||
|
||||
# Stop chat completion consumer
|
||||
try:
|
||||
await stop_completion_consumer()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error stopping chat completion consumer: {e}")
|
||||
|
||||
try:
|
||||
await shutdown_cloud_storage_handler()
|
||||
except Exception as e:
|
||||
|
||||
@@ -11,7 +11,6 @@ import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { usePathname, useSearchParams } from "next/navigation";
|
||||
import { useRef } from "react";
|
||||
import { useCopilotStore } from "../../copilot-page-store";
|
||||
import { useCopilotSessionId } from "../../useCopilotSessionId";
|
||||
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
|
||||
@@ -70,41 +69,16 @@ export function useCopilotShell() {
|
||||
});
|
||||
|
||||
const stopStream = useChatStore((s) => s.stopStream);
|
||||
const onStreamComplete = useChatStore((s) => s.onStreamComplete);
|
||||
const isStreaming = useCopilotStore((s) => s.isStreaming);
|
||||
const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);
|
||||
const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession);
|
||||
const openInterruptModal = useCopilotStore((s) => s.openInterruptModal);
|
||||
|
||||
const pendingActionRef = useRef<(() => void) | null>(null);
|
||||
|
||||
async function stopCurrentStream() {
|
||||
if (!currentSessionId) return;
|
||||
|
||||
setIsSwitchingSession(true);
|
||||
await new Promise<void>((resolve) => {
|
||||
const unsubscribe = onStreamComplete((completedId) => {
|
||||
if (completedId === currentSessionId) {
|
||||
clearTimeout(timeout);
|
||||
unsubscribe();
|
||||
resolve();
|
||||
}
|
||||
});
|
||||
const timeout = setTimeout(() => {
|
||||
unsubscribe();
|
||||
resolve();
|
||||
}, 3000);
|
||||
stopStream(currentSessionId);
|
||||
});
|
||||
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(currentSessionId),
|
||||
});
|
||||
setIsSwitchingSession(false);
|
||||
}
|
||||
|
||||
function selectSession(sessionId: string) {
|
||||
function handleSessionClick(sessionId: string) {
|
||||
if (sessionId === currentSessionId) return;
|
||||
|
||||
// Stop current stream - SSE reconnection allows resuming later
|
||||
if (currentSessionId) {
|
||||
stopStream(currentSessionId);
|
||||
}
|
||||
|
||||
if (recentlyCreatedSessionsRef.current.has(sessionId)) {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(sessionId),
|
||||
@@ -114,7 +88,12 @@ export function useCopilotShell() {
|
||||
if (isMobile) handleCloseDrawer();
|
||||
}
|
||||
|
||||
function startNewChat() {
|
||||
function handleNewChatClick() {
|
||||
// Stop current stream - SSE reconnection allows resuming later
|
||||
if (currentSessionId) {
|
||||
stopStream(currentSessionId);
|
||||
}
|
||||
|
||||
resetPagination();
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
@@ -123,32 +102,6 @@ export function useCopilotShell() {
|
||||
if (isMobile) handleCloseDrawer();
|
||||
}
|
||||
|
||||
function handleSessionClick(sessionId: string) {
|
||||
if (sessionId === currentSessionId) return;
|
||||
|
||||
if (isStreaming) {
|
||||
pendingActionRef.current = async () => {
|
||||
await stopCurrentStream();
|
||||
selectSession(sessionId);
|
||||
};
|
||||
openInterruptModal(pendingActionRef.current);
|
||||
} else {
|
||||
selectSession(sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
function handleNewChatClick() {
|
||||
if (isStreaming) {
|
||||
pendingActionRef.current = async () => {
|
||||
await stopCurrentStream();
|
||||
startNewChat();
|
||||
};
|
||||
openInterruptModal(pendingActionRef.current);
|
||||
} else {
|
||||
startNewChat();
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
isMobile,
|
||||
isDrawerOpen,
|
||||
|
||||
@@ -4,53 +4,25 @@ import { create } from "zustand";
|
||||
|
||||
interface CopilotStoreState {
|
||||
isStreaming: boolean;
|
||||
isSwitchingSession: boolean;
|
||||
isCreatingSession: boolean;
|
||||
isInterruptModalOpen: boolean;
|
||||
pendingAction: (() => void) | null;
|
||||
}
|
||||
|
||||
interface CopilotStoreActions {
|
||||
setIsStreaming: (isStreaming: boolean) => void;
|
||||
setIsSwitchingSession: (isSwitchingSession: boolean) => void;
|
||||
setIsCreatingSession: (isCreating: boolean) => void;
|
||||
openInterruptModal: (onConfirm: () => void) => void;
|
||||
confirmInterrupt: () => void;
|
||||
cancelInterrupt: () => void;
|
||||
}
|
||||
|
||||
type CopilotStore = CopilotStoreState & CopilotStoreActions;
|
||||
|
||||
export const useCopilotStore = create<CopilotStore>((set, get) => ({
|
||||
export const useCopilotStore = create<CopilotStore>((set) => ({
|
||||
isStreaming: false,
|
||||
isSwitchingSession: false,
|
||||
isCreatingSession: false,
|
||||
isInterruptModalOpen: false,
|
||||
pendingAction: null,
|
||||
|
||||
setIsStreaming(isStreaming) {
|
||||
set({ isStreaming });
|
||||
},
|
||||
|
||||
setIsSwitchingSession(isSwitchingSession) {
|
||||
set({ isSwitchingSession });
|
||||
},
|
||||
|
||||
setIsCreatingSession(isCreatingSession) {
|
||||
set({ isCreatingSession });
|
||||
},
|
||||
|
||||
openInterruptModal(onConfirm) {
|
||||
set({ isInterruptModalOpen: true, pendingAction: onConfirm });
|
||||
},
|
||||
|
||||
confirmInterrupt() {
|
||||
const { pendingAction } = get();
|
||||
set({ isInterruptModalOpen: false, pendingAction: null });
|
||||
if (pendingAction) pendingAction();
|
||||
},
|
||||
|
||||
cancelInterrupt() {
|
||||
set({ isInterruptModalOpen: false, pendingAction: null });
|
||||
},
|
||||
}));
|
||||
|
||||
@@ -5,15 +5,10 @@ import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Chat } from "@/components/contextual/Chat/Chat";
|
||||
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { useCopilotStore } from "./copilot-page-store";
|
||||
import { useCopilotPage } from "./useCopilotPage";
|
||||
|
||||
export default function CopilotPage() {
|
||||
const { state, handlers } = useCopilotPage();
|
||||
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
|
||||
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
|
||||
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
|
||||
const {
|
||||
greetingName,
|
||||
quickActions,
|
||||
@@ -40,42 +35,6 @@ export default function CopilotPage() {
|
||||
onSessionNotFound={handleSessionNotFound}
|
||||
onStreamingChange={handleStreamingChange}
|
||||
/>
|
||||
<Dialog
|
||||
title="Interrupt current chat?"
|
||||
styling={{ maxWidth: 300, width: "100%" }}
|
||||
controlled={{
|
||||
isOpen: isInterruptModalOpen,
|
||||
set: (open) => {
|
||||
if (!open) cancelInterrupt();
|
||||
},
|
||||
}}
|
||||
onClose={cancelInterrupt}
|
||||
>
|
||||
<Dialog.Content>
|
||||
<div className="flex flex-col gap-4">
|
||||
<Text variant="body">
|
||||
The current chat response will be interrupted. Are you sure you
|
||||
want to continue?
|
||||
</Text>
|
||||
<Dialog.Footer>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
onClick={cancelInterrupt}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
type="button"
|
||||
variant="primary"
|
||||
onClick={confirmInterrupt}
|
||||
>
|
||||
Continue
|
||||
</Button>
|
||||
</Dialog.Footer>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import {
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
|
||||
import {
|
||||
Flag,
|
||||
type FlagValues,
|
||||
@@ -26,20 +25,12 @@ export function useCopilotPage() {
|
||||
const queryClient = useQueryClient();
|
||||
const { user, isLoggedIn, isUserLoading } = useSupabase();
|
||||
const { toast } = useToast();
|
||||
const { completeStep } = useOnboarding();
|
||||
|
||||
const { urlSessionId, setUrlSessionId } = useCopilotSessionId();
|
||||
const setIsStreaming = useCopilotStore((s) => s.setIsStreaming);
|
||||
const isCreating = useCopilotStore((s) => s.isCreatingSession);
|
||||
const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession);
|
||||
|
||||
// Complete VISIT_COPILOT onboarding step to grant $5 welcome bonus
|
||||
useEffect(() => {
|
||||
if (isLoggedIn) {
|
||||
completeStep("VISIT_COPILOT");
|
||||
}
|
||||
}, [completeStep, isLoggedIn]);
|
||||
|
||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||
const flags = useFlags<FlagValues>();
|
||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
import { environment } from "@/services/environment";
|
||||
import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
|
||||
import { NextRequest } from "next/server";
|
||||
|
||||
/**
|
||||
* SSE Proxy for task stream reconnection.
|
||||
*
|
||||
* This endpoint allows clients to reconnect to an ongoing or recently completed
|
||||
* background task's stream. It replays missed messages from Redis Streams and
|
||||
* subscribes to live updates if the task is still running.
|
||||
*
|
||||
* Client contract:
|
||||
* 1. When receiving an operation_started event, store the task_id
|
||||
* 2. To reconnect: GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
|
||||
* 3. Messages are replayed from the last_message_id position
|
||||
* 4. Stream ends when "finish" event is received
|
||||
*/
|
||||
export async function GET(
|
||||
request: NextRequest,
|
||||
{ params }: { params: Promise<{ taskId: string }> },
|
||||
) {
|
||||
const { taskId } = await params;
|
||||
const searchParams = request.nextUrl.searchParams;
|
||||
const lastMessageId = searchParams.get("last_message_id") || "0-0";
|
||||
|
||||
try {
|
||||
// Get auth token from server-side session
|
||||
const token = await getServerAuthToken();
|
||||
|
||||
// Build backend URL
|
||||
const backendUrl = environment.getAGPTServerBaseUrl();
|
||||
const streamUrl = new URL(`/api/chat/tasks/${taskId}/stream`, backendUrl);
|
||||
streamUrl.searchParams.set("last_message_id", lastMessageId);
|
||||
|
||||
// Forward request to backend with auth header
|
||||
const headers: Record<string, string> = {
|
||||
Accept: "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
Connection: "keep-alive",
|
||||
};
|
||||
|
||||
if (token) {
|
||||
headers["Authorization"] = `Bearer ${token}`;
|
||||
}
|
||||
|
||||
const response = await fetch(streamUrl.toString(), {
|
||||
method: "GET",
|
||||
headers,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text();
|
||||
return new Response(error, {
|
||||
status: response.status,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
});
|
||||
}
|
||||
|
||||
// Return the SSE stream directly
|
||||
return new Response(response.body, {
|
||||
headers: {
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache, no-transform",
|
||||
Connection: "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Task stream proxy error:", error);
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
error: "Failed to connect to task stream",
|
||||
detail: error instanceof Error ? error.message : String(error),
|
||||
}),
|
||||
{
|
||||
status: 500,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -939,6 +939,63 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/operations/{operation_id}/complete": {
|
||||
"post": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Complete Operation",
|
||||
"description": "External completion webhook for long-running operations.\n\nCalled by Agent Generator (or other services) when an operation completes.\nThis triggers the stream registry to publish completion and continue LLM generation.\n\nArgs:\n operation_id: The operation ID to complete.\n request: Completion payload with success status and result/error.\n x_api_key: Internal API key for authentication.\n\nReturns:\n dict: Status of the completion.\n\nRaises:\n HTTPException: If API key is invalid or operation not found.",
|
||||
"operationId": "postV2CompleteOperation",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "operation_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Operation Id" }
|
||||
},
|
||||
{
|
||||
"name": "x-api-key",
|
||||
"in": "header",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "X-Api-Key"
|
||||
}
|
||||
}
|
||||
],
|
||||
"requestBody": {
|
||||
"required": true,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/OperationCompleteRequest"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"title": "Response Postv2Completeoperation"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/sessions": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
@@ -1195,6 +1252,94 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/tasks/{task_id}": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Get Task Status",
|
||||
"description": "Get the status of a long-running task.\n\nArgs:\n task_id: The task ID to check.\n user_id: Authenticated user ID for ownership validation.\n\nReturns:\n dict: Task status including task_id, status, tool_name, and operation_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.",
|
||||
"operationId": "getV2GetTaskStatus",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "task_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Task Id" }
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"title": "Response Getv2Gettaskstatus"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/tasks/{task_id}/stream": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Stream Task",
|
||||
"description": "Reconnect to a long-running task's SSE stream.\n\nWhen a long-running operation (like agent generation) starts, the client\nreceives a task_id. If the connection drops, the client can reconnect\nusing this endpoint to resume receiving updates.\n\nArgs:\n task_id: The task ID from the operation_started response.\n user_id: Authenticated user ID for ownership validation.\n last_message_id: Last Redis Stream message ID received (\"0-0\" for full replay).\n\nReturns:\n StreamingResponse: SSE-formatted response chunks starting after last_message_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.",
|
||||
"operationId": "getV2StreamTask",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "task_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Task Id" }
|
||||
},
|
||||
{
|
||||
"name": "last_message_id",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
|
||||
"default": "0-0",
|
||||
"title": "Last Message Id"
|
||||
},
|
||||
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay."
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": { "application/json": { "schema": {} } }
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/credits": {
|
||||
"get": {
|
||||
"tags": ["v1", "credits"],
|
||||
@@ -8804,6 +8949,27 @@
|
||||
],
|
||||
"title": "OnboardingStep"
|
||||
},
|
||||
"OperationCompleteRequest": {
|
||||
"properties": {
|
||||
"success": { "type": "boolean", "title": "Success" },
|
||||
"result": {
|
||||
"anyOf": [
|
||||
{ "additionalProperties": true, "type": "object" },
|
||||
{ "type": "string" },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Result"
|
||||
},
|
||||
"error": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Error"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["success"],
|
||||
"title": "OperationCompleteRequest",
|
||||
"description": "Request model for external completion webhook."
|
||||
},
|
||||
"Pagination": {
|
||||
"properties": {
|
||||
"total_items": {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId";
|
||||
import { useCopilotStore } from "@/app/(platform)/copilot/copilot-page-store";
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
@@ -25,7 +24,6 @@ export function Chat({
|
||||
}: ChatProps) {
|
||||
const { urlSessionId } = useCopilotSessionId();
|
||||
const hasHandledNotFoundRef = useRef(false);
|
||||
const isSwitchingSession = useCopilotStore((s) => s.isSwitchingSession);
|
||||
const {
|
||||
messages,
|
||||
isLoading,
|
||||
@@ -53,8 +51,7 @@ export function Chat({
|
||||
isCreating,
|
||||
]);
|
||||
|
||||
const shouldShowLoader =
|
||||
(showLoader && (isLoading || isCreating)) || isSwitchingSession;
|
||||
const shouldShowLoader = showLoader && (isLoading || isCreating);
|
||||
|
||||
return (
|
||||
<div className={cn("flex h-full flex-col", className)}>
|
||||
@@ -66,21 +63,19 @@ export function Chat({
|
||||
<div className="flex flex-col items-center gap-3">
|
||||
<LoadingSpinner size="large" className="text-neutral-400" />
|
||||
<Text variant="body" className="text-zinc-500">
|
||||
{isSwitchingSession
|
||||
? "Switching chat..."
|
||||
: "Loading your chat..."}
|
||||
Loading your chat...
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Error State */}
|
||||
{error && !isLoading && !isSwitchingSession && (
|
||||
{error && !isLoading && (
|
||||
<ChatErrorState error={error} onRetry={createSession} />
|
||||
)}
|
||||
|
||||
{/* Session Content */}
|
||||
{sessionId && !isLoading && !error && !isSwitchingSession && (
|
||||
{sessionId && !isLoading && !error && (
|
||||
<ChatContainer
|
||||
sessionId={sessionId}
|
||||
initialMessages={messages}
|
||||
|
||||
@@ -0,0 +1,156 @@
|
||||
# SSE Reconnection Contract for Long-Running Operations
|
||||
|
||||
This document describes the client-side contract for handling SSE (Server-Sent Events) disconnections and reconnecting to long-running background tasks.
|
||||
|
||||
## Overview
|
||||
|
||||
When a user triggers a long-running operation (like agent generation), the backend:
|
||||
1. Spawns a background task that survives SSE disconnections
|
||||
2. Returns an `operation_started` response with a `task_id`
|
||||
3. Stores stream messages in Redis Streams for replay
|
||||
|
||||
Clients can reconnect to the task stream at any time to receive missed messages.
|
||||
|
||||
## Client-Side Flow
|
||||
|
||||
### 1. Receiving Operation Started
|
||||
|
||||
When you receive an `operation_started` tool response:
|
||||
|
||||
```typescript
|
||||
// The response includes a task_id for reconnection
|
||||
{
|
||||
type: "operation_started",
|
||||
tool_name: "generate_agent",
|
||||
operation_id: "uuid-...",
|
||||
task_id: "task-uuid-...", // <-- Store this for reconnection
|
||||
message: "Operation started. You can close this tab."
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Storing Task Info
|
||||
|
||||
Use the chat store to track the active task:
|
||||
|
||||
```typescript
|
||||
import { useChatStore } from "./chat-store";
|
||||
|
||||
// When operation_started is received:
|
||||
useChatStore.getState().setActiveTask(sessionId, {
|
||||
taskId: response.task_id,
|
||||
operationId: response.operation_id,
|
||||
toolName: response.tool_name,
|
||||
lastMessageId: "0",
|
||||
});
|
||||
```
|
||||
|
||||
### 3. Reconnecting to a Task
|
||||
|
||||
To reconnect (e.g., after page refresh or tab reopen):
|
||||
|
||||
```typescript
|
||||
const { reconnectToTask, getActiveTask } = useChatStore.getState();
|
||||
|
||||
// Check if there's an active task for this session
|
||||
const activeTask = getActiveTask(sessionId);
|
||||
|
||||
if (activeTask) {
|
||||
// Reconnect to the task stream
|
||||
await reconnectToTask(
|
||||
sessionId,
|
||||
activeTask.taskId,
|
||||
activeTask.lastMessageId, // Resume from last position
|
||||
(chunk) => {
|
||||
// Handle incoming chunks
|
||||
console.log("Received chunk:", chunk);
|
||||
}
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Tracking Message Position
|
||||
|
||||
To enable precise replay, update the last message ID as chunks arrive:
|
||||
|
||||
```typescript
|
||||
const { updateTaskLastMessageId } = useChatStore.getState();
|
||||
|
||||
function handleChunk(chunk: StreamChunk) {
|
||||
// If chunk has an index/id, track it
|
||||
if (chunk.idx !== undefined) {
|
||||
updateTaskLastMessageId(sessionId, String(chunk.idx));
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### Task Stream Reconnection
|
||||
```
|
||||
GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
|
||||
```
|
||||
|
||||
- `taskId`: The task ID from `operation_started`
|
||||
- `last_message_id`: Last received message index (default: "0" for full replay)
|
||||
|
||||
Returns: SSE stream of missed messages + live updates
|
||||
|
||||
## Chunk Types
|
||||
|
||||
The reconnected stream follows the same Vercel AI SDK protocol:
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `start` | Message lifecycle start |
|
||||
| `text-delta` | Streaming text content |
|
||||
| `text-end` | Text block completed |
|
||||
| `tool-output-available` | Tool result available |
|
||||
| `finish` | Stream completed |
|
||||
| `error` | Error occurred |
|
||||
|
||||
## Error Handling
|
||||
|
||||
If reconnection fails:
|
||||
1. Check if task still exists (may have expired - default TTL: 1 hour)
|
||||
2. Fall back to polling the session for final state
|
||||
3. Show appropriate UI message to user
|
||||
|
||||
## Persistence Considerations
|
||||
|
||||
For robust reconnection across browser restarts:
|
||||
|
||||
```typescript
|
||||
// Store in localStorage/sessionStorage
|
||||
const ACTIVE_TASKS_KEY = "chat_active_tasks";
|
||||
|
||||
function persistActiveTask(sessionId: string, task: ActiveTaskInfo) {
|
||||
const tasks = JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
|
||||
tasks[sessionId] = task;
|
||||
localStorage.setItem(ACTIVE_TASKS_KEY, JSON.stringify(tasks));
|
||||
}
|
||||
|
||||
function loadPersistedTasks(): Record<string, ActiveTaskInfo> {
|
||||
return JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
|
||||
}
|
||||
```
|
||||
|
||||
## Backend Configuration
|
||||
|
||||
The following backend settings affect reconnection behavior:
|
||||
|
||||
| Setting | Default | Description |
|
||||
|---------|---------|-------------|
|
||||
| `stream_ttl` | 3600s | How long streams are kept in Redis |
|
||||
| `stream_max_length` | 1000 | Max messages per stream |
|
||||
|
||||
## Testing
|
||||
|
||||
To test reconnection locally:
|
||||
|
||||
1. Start a long-running operation (e.g., agent generation)
|
||||
2. Note the `task_id` from the `operation_started` response
|
||||
3. Close the browser tab
|
||||
4. Reopen and call `reconnectToTask` with the saved `task_id`
|
||||
5. Verify that missed messages are replayed
|
||||
|
||||
See the main README for full local development setup.
|
||||
@@ -8,15 +8,68 @@ import type {
|
||||
StreamResult,
|
||||
StreamStatus,
|
||||
} from "./chat-types";
|
||||
import { executeStream } from "./stream-executor";
|
||||
import { executeStream, executeTaskReconnect } from "./stream-executor";
|
||||
|
||||
const COMPLETED_STREAM_TTL = 5 * 60 * 1000; // 5 minutes
|
||||
const ACTIVE_TASKS_STORAGE_KEY = "chat_active_tasks";
|
||||
const TASK_TTL = 60 * 60 * 1000; // 1 hour - tasks expire after this
|
||||
|
||||
/**
|
||||
* Tracks active task info for SSE reconnection.
|
||||
* When a long-running operation starts, we store this so clients can reconnect
|
||||
* if the browser tab is closed and reopened.
|
||||
*/
|
||||
export interface ActiveTaskInfo {
|
||||
taskId: string;
|
||||
sessionId: string;
|
||||
operationId: string;
|
||||
toolName: string;
|
||||
lastMessageId: string; // Last processed message ID for replay (Redis Stream format: "0-0")
|
||||
startedAt: number;
|
||||
}
|
||||
|
||||
/** Load active tasks from localStorage */
|
||||
function loadPersistedTasks(): Map<string, ActiveTaskInfo> {
|
||||
if (typeof window === "undefined") return new Map();
|
||||
try {
|
||||
const stored = localStorage.getItem(ACTIVE_TASKS_STORAGE_KEY);
|
||||
if (!stored) return new Map();
|
||||
const parsed = JSON.parse(stored) as Record<string, ActiveTaskInfo>;
|
||||
const now = Date.now();
|
||||
const tasks = new Map<string, ActiveTaskInfo>();
|
||||
// Filter out expired tasks
|
||||
for (const [sessionId, task] of Object.entries(parsed)) {
|
||||
if (now - task.startedAt < TASK_TTL) {
|
||||
tasks.set(sessionId, task);
|
||||
}
|
||||
}
|
||||
return tasks;
|
||||
} catch {
|
||||
return new Map();
|
||||
}
|
||||
}
|
||||
|
||||
/** Save active tasks to localStorage */
|
||||
function persistTasks(tasks: Map<string, ActiveTaskInfo>): void {
|
||||
if (typeof window === "undefined") return;
|
||||
try {
|
||||
const obj: Record<string, ActiveTaskInfo> = {};
|
||||
for (const [sessionId, task] of tasks) {
|
||||
obj[sessionId] = task;
|
||||
}
|
||||
localStorage.setItem(ACTIVE_TASKS_STORAGE_KEY, JSON.stringify(obj));
|
||||
} catch {
|
||||
// Ignore storage errors
|
||||
}
|
||||
}
|
||||
|
||||
interface ChatStoreState {
|
||||
activeStreams: Map<string, ActiveStream>;
|
||||
completedStreams: Map<string, StreamResult>;
|
||||
activeSessions: Set<string>;
|
||||
streamCompleteCallbacks: Set<StreamCompleteCallback>;
|
||||
/** Active tasks for SSE reconnection - keyed by sessionId */
|
||||
activeTasks: Map<string, ActiveTaskInfo>;
|
||||
}
|
||||
|
||||
interface ChatStoreActions {
|
||||
@@ -41,6 +94,24 @@ interface ChatStoreActions {
|
||||
unregisterActiveSession: (sessionId: string) => void;
|
||||
isSessionActive: (sessionId: string) => boolean;
|
||||
onStreamComplete: (callback: StreamCompleteCallback) => () => void;
|
||||
/** Track active task for SSE reconnection */
|
||||
setActiveTask: (
|
||||
sessionId: string,
|
||||
taskInfo: Omit<ActiveTaskInfo, "sessionId" | "startedAt">,
|
||||
) => void;
|
||||
/** Get active task for a session */
|
||||
getActiveTask: (sessionId: string) => ActiveTaskInfo | undefined;
|
||||
/** Clear active task when operation completes */
|
||||
clearActiveTask: (sessionId: string) => void;
|
||||
/** Reconnect to an existing task stream */
|
||||
reconnectToTask: (
|
||||
sessionId: string,
|
||||
taskId: string,
|
||||
lastMessageId?: string,
|
||||
onChunk?: (chunk: StreamChunk) => void,
|
||||
) => Promise<void>;
|
||||
/** Update last message ID for a task (for tracking replay position) */
|
||||
updateTaskLastMessageId: (sessionId: string, lastMessageId: string) => void;
|
||||
}
|
||||
|
||||
type ChatStore = ChatStoreState & ChatStoreActions;
|
||||
@@ -76,6 +147,7 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
||||
completedStreams: new Map(),
|
||||
activeSessions: new Set(),
|
||||
streamCompleteCallbacks: new Set(),
|
||||
activeTasks: loadPersistedTasks(),
|
||||
|
||||
startStream: async function startStream(
|
||||
sessionId,
|
||||
@@ -286,4 +358,139 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
||||
set({ streamCompleteCallbacks: cleanedCallbacks });
|
||||
};
|
||||
},
|
||||
|
||||
setActiveTask: function setActiveTask(sessionId, taskInfo) {
|
||||
const state = get();
|
||||
const newActiveTasks = new Map(state.activeTasks);
|
||||
newActiveTasks.set(sessionId, {
|
||||
...taskInfo,
|
||||
sessionId,
|
||||
startedAt: Date.now(),
|
||||
});
|
||||
set({ activeTasks: newActiveTasks });
|
||||
persistTasks(newActiveTasks);
|
||||
},
|
||||
|
||||
getActiveTask: function getActiveTask(sessionId) {
|
||||
return get().activeTasks.get(sessionId);
|
||||
},
|
||||
|
||||
clearActiveTask: function clearActiveTask(sessionId) {
|
||||
const state = get();
|
||||
if (!state.activeTasks.has(sessionId)) return;
|
||||
|
||||
const newActiveTasks = new Map(state.activeTasks);
|
||||
newActiveTasks.delete(sessionId);
|
||||
set({ activeTasks: newActiveTasks });
|
||||
persistTasks(newActiveTasks);
|
||||
},
|
||||
|
||||
reconnectToTask: async function reconnectToTask(
|
||||
sessionId,
|
||||
taskId,
|
||||
lastMessageId = "0-0", // Redis Stream ID format
|
||||
onChunk,
|
||||
) {
|
||||
const state = get();
|
||||
const newActiveStreams = new Map(state.activeStreams);
|
||||
let newCompletedStreams = new Map(state.completedStreams);
|
||||
const callbacks = state.streamCompleteCallbacks;
|
||||
|
||||
// Clean up any existing stream for this session
|
||||
const existingStream = newActiveStreams.get(sessionId);
|
||||
if (existingStream) {
|
||||
existingStream.abortController.abort();
|
||||
const normalizedStatus =
|
||||
existingStream.status === "streaming"
|
||||
? "completed"
|
||||
: existingStream.status;
|
||||
const result: StreamResult = {
|
||||
sessionId,
|
||||
status: normalizedStatus,
|
||||
chunks: existingStream.chunks,
|
||||
completedAt: Date.now(),
|
||||
error: existingStream.error,
|
||||
};
|
||||
newCompletedStreams.set(sessionId, result);
|
||||
newActiveStreams.delete(sessionId);
|
||||
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
|
||||
}
|
||||
|
||||
const abortController = new AbortController();
|
||||
const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
|
||||
if (onChunk) initialCallbacks.add(onChunk);
|
||||
|
||||
const stream: ActiveStream = {
|
||||
sessionId,
|
||||
abortController,
|
||||
status: "streaming",
|
||||
startedAt: Date.now(),
|
||||
chunks: [],
|
||||
onChunkCallbacks: initialCallbacks,
|
||||
};
|
||||
|
||||
newActiveStreams.set(sessionId, stream);
|
||||
set({
|
||||
activeStreams: newActiveStreams,
|
||||
completedStreams: newCompletedStreams,
|
||||
});
|
||||
|
||||
try {
|
||||
await executeTaskReconnect(stream, taskId, lastMessageId);
|
||||
} finally {
|
||||
if (onChunk) stream.onChunkCallbacks.delete(onChunk);
|
||||
if (stream.status !== "streaming") {
|
||||
const currentState = get();
|
||||
const finalActiveStreams = new Map(currentState.activeStreams);
|
||||
let finalCompletedStreams = new Map(currentState.completedStreams);
|
||||
|
||||
const storedStream = finalActiveStreams.get(sessionId);
|
||||
if (storedStream === stream) {
|
||||
const result: StreamResult = {
|
||||
sessionId,
|
||||
status: stream.status,
|
||||
chunks: stream.chunks,
|
||||
completedAt: Date.now(),
|
||||
error: stream.error,
|
||||
};
|
||||
finalCompletedStreams.set(sessionId, result);
|
||||
finalActiveStreams.delete(sessionId);
|
||||
finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
|
||||
set({
|
||||
activeStreams: finalActiveStreams,
|
||||
completedStreams: finalCompletedStreams,
|
||||
});
|
||||
if (stream.status === "completed" || stream.status === "error") {
|
||||
notifyStreamComplete(
|
||||
currentState.streamCompleteCallbacks,
|
||||
sessionId,
|
||||
);
|
||||
// Clear active task on completion
|
||||
const taskState = get();
|
||||
const newActiveTasks = new Map(taskState.activeTasks);
|
||||
newActiveTasks.delete(sessionId);
|
||||
set({ activeTasks: newActiveTasks });
|
||||
persistTasks(newActiveTasks);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
updateTaskLastMessageId: function updateTaskLastMessageId(
|
||||
sessionId,
|
||||
lastMessageId,
|
||||
) {
|
||||
const state = get();
|
||||
const task = state.activeTasks.get(sessionId);
|
||||
if (!task) return;
|
||||
|
||||
const newActiveTasks = new Map(state.activeTasks);
|
||||
newActiveTasks.set(sessionId, {
|
||||
...task,
|
||||
lastMessageId,
|
||||
});
|
||||
set({ activeTasks: newActiveTasks });
|
||||
persistTasks(newActiveTasks);
|
||||
},
|
||||
}));
|
||||
|
||||
@@ -23,6 +23,12 @@ export interface HandlerDependencies {
|
||||
setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
|
||||
sessionId: string;
|
||||
onOperationStarted?: () => void;
|
||||
onActiveTaskStarted?: (taskInfo: {
|
||||
taskId: string;
|
||||
operationId: string;
|
||||
toolName: string;
|
||||
toolCallId: string;
|
||||
}) => void;
|
||||
}
|
||||
|
||||
export function isRegionBlockedError(chunk: StreamChunk): boolean {
|
||||
@@ -164,9 +170,19 @@ export function handleToolResponse(
|
||||
}
|
||||
return;
|
||||
}
|
||||
// Trigger polling when operation_started is received
|
||||
// Trigger polling and store task info when operation_started is received
|
||||
if (responseMessage.type === "operation_started") {
|
||||
deps.onOperationStarted?.();
|
||||
// Store task info for SSE reconnection if taskId is present
|
||||
const taskId = (responseMessage as any).taskId;
|
||||
if (taskId && deps.onActiveTaskStarted) {
|
||||
deps.onActiveTaskStarted({
|
||||
taskId,
|
||||
operationId: (responseMessage as any).operationId || "",
|
||||
toolName: (responseMessage as any).toolName || "",
|
||||
toolCallId: (responseMessage as any).toolId || "",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
deps.setMessages((prev) => {
|
||||
|
||||
@@ -349,6 +349,7 @@ export function parseToolResponse(
|
||||
toolName: (parsedResult.tool_name as string) || toolName,
|
||||
toolId,
|
||||
operationId: (parsedResult.operation_id as string) || "",
|
||||
taskId: (parsedResult.task_id as string) || undefined, // For SSE reconnection
|
||||
message:
|
||||
(parsedResult.message as string) ||
|
||||
"Operation started. You can close this tab.",
|
||||
|
||||
@@ -65,8 +65,27 @@ export function useChatContainer({
|
||||
} = useChatStream();
|
||||
const activeStreams = useChatStore((s) => s.activeStreams);
|
||||
const subscribeToStream = useChatStore((s) => s.subscribeToStream);
|
||||
const setActiveTask = useChatStore((s) => s.setActiveTask);
|
||||
const getActiveTask = useChatStore((s) => s.getActiveTask);
|
||||
const reconnectToTask = useChatStore((s) => s.reconnectToTask);
|
||||
const isStreaming = isStreamingInitiated || hasTextChunks;
|
||||
|
||||
// Callback to store active task info for SSE reconnection
|
||||
function handleActiveTaskStarted(taskInfo: {
|
||||
taskId: string;
|
||||
operationId: string;
|
||||
toolName: string;
|
||||
toolCallId: string;
|
||||
}) {
|
||||
if (!sessionId) return;
|
||||
setActiveTask(sessionId, {
|
||||
taskId: taskInfo.taskId,
|
||||
operationId: taskInfo.operationId,
|
||||
toolName: taskInfo.toolName,
|
||||
lastMessageId: "0-0", // Redis Stream ID format for full replay
|
||||
});
|
||||
}
|
||||
|
||||
useEffect(
|
||||
function handleSessionChange() {
|
||||
if (sessionId === previousSessionIdRef.current) return;
|
||||
@@ -85,6 +104,34 @@ export function useChatContainer({
|
||||
|
||||
if (!sessionId) return;
|
||||
|
||||
// Check if there's an active task for this session that we should reconnect to
|
||||
const activeTask = getActiveTask(sessionId);
|
||||
if (activeTask) {
|
||||
const dispatcher = createStreamEventDispatcher({
|
||||
setHasTextChunks,
|
||||
setStreamingChunks,
|
||||
streamingChunksRef,
|
||||
hasResponseRef,
|
||||
setMessages,
|
||||
setIsRegionBlockedModalOpen,
|
||||
sessionId,
|
||||
setIsStreamingInitiated,
|
||||
onOperationStarted,
|
||||
onActiveTaskStarted: handleActiveTaskStarted,
|
||||
});
|
||||
|
||||
setIsStreamingInitiated(true);
|
||||
// Reconnect to the task stream
|
||||
reconnectToTask(
|
||||
sessionId,
|
||||
activeTask.taskId,
|
||||
activeTask.lastMessageId,
|
||||
dispatcher,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise check for an in-memory active stream
|
||||
const activeStream = activeStreams.get(sessionId);
|
||||
if (!activeStream || activeStream.status !== "streaming") return;
|
||||
|
||||
@@ -98,6 +145,7 @@ export function useChatContainer({
|
||||
sessionId,
|
||||
setIsStreamingInitiated,
|
||||
onOperationStarted,
|
||||
onActiveTaskStarted: handleActiveTaskStarted,
|
||||
});
|
||||
|
||||
setIsStreamingInitiated(true);
|
||||
@@ -110,6 +158,8 @@ export function useChatContainer({
|
||||
activeStreams,
|
||||
subscribeToStream,
|
||||
onOperationStarted,
|
||||
getActiveTask,
|
||||
reconnectToTask,
|
||||
],
|
||||
);
|
||||
|
||||
@@ -225,6 +275,7 @@ export function useChatContainer({
|
||||
sessionId,
|
||||
setIsStreamingInitiated,
|
||||
onOperationStarted,
|
||||
onActiveTaskStarted: handleActiveTaskStarted,
|
||||
});
|
||||
|
||||
try {
|
||||
|
||||
@@ -111,6 +111,7 @@ export type ChatMessageData =
|
||||
toolName: string;
|
||||
toolId: string;
|
||||
operationId: string;
|
||||
taskId?: string; // For SSE reconnection
|
||||
message: string;
|
||||
timestamp?: string | Date;
|
||||
}
|
||||
|
||||
@@ -10,8 +10,14 @@ import {
|
||||
parseSSELine,
|
||||
} from "./stream-utils";
|
||||
|
||||
function notifySubscribers(stream: ActiveStream, chunk: StreamChunk) {
|
||||
stream.chunks.push(chunk);
|
||||
function notifySubscribers(
|
||||
stream: ActiveStream,
|
||||
chunk: StreamChunk,
|
||||
skipStore = false,
|
||||
) {
|
||||
if (!skipStore) {
|
||||
stream.chunks.push(chunk);
|
||||
}
|
||||
for (const callback of stream.onChunkCallbacks) {
|
||||
try {
|
||||
callback(chunk);
|
||||
@@ -140,3 +146,124 @@ export async function executeStream(
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reconnect to an existing task stream.
|
||||
*
|
||||
* This is used when a client wants to resume receiving updates from a
|
||||
* long-running background task. Messages are replayed from the last_message_id
|
||||
* position, allowing clients to catch up on missed events.
|
||||
*
|
||||
* @param stream - The active stream state
|
||||
* @param taskId - The task ID to reconnect to
|
||||
* @param lastMessageId - The last message ID received (for replay)
|
||||
* @param retryCount - Current retry count
|
||||
*/
|
||||
export async function executeTaskReconnect(
|
||||
stream: ActiveStream,
|
||||
taskId: string,
|
||||
lastMessageId: string = "0",
|
||||
retryCount: number = 0,
|
||||
): Promise<void> {
|
||||
const { abortController } = stream;
|
||||
|
||||
try {
|
||||
const url = `/api/chat/tasks/${taskId}/stream?last_message_id=${encodeURIComponent(lastMessageId)}`;
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
Accept: "text/event-stream",
|
||||
},
|
||||
signal: abortController.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(errorText || `HTTP ${response.status}`);
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
throw new Error("Response body is null");
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
|
||||
if (done) {
|
||||
notifySubscribers(stream, { type: "stream_end" });
|
||||
stream.status = "completed";
|
||||
return;
|
||||
}
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split("\n");
|
||||
buffer = lines.pop() || "";
|
||||
|
||||
for (const line of lines) {
|
||||
const data = parseSSELine(line);
|
||||
if (data !== null) {
|
||||
if (data === "[DONE]") {
|
||||
notifySubscribers(stream, { type: "stream_end" });
|
||||
stream.status = "completed";
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const rawChunk = JSON.parse(data) as
|
||||
| StreamChunk
|
||||
| VercelStreamChunk;
|
||||
const chunk = normalizeStreamChunk(rawChunk);
|
||||
if (!chunk) continue;
|
||||
|
||||
notifySubscribers(stream, chunk);
|
||||
|
||||
if (chunk.type === "stream_end") {
|
||||
stream.status = "completed";
|
||||
return;
|
||||
}
|
||||
|
||||
if (chunk.type === "error") {
|
||||
stream.status = "error";
|
||||
stream.error = new Error(
|
||||
chunk.message || chunk.content || "Stream error",
|
||||
);
|
||||
return;
|
||||
}
|
||||
} catch (err) {
|
||||
console.warn(
|
||||
"[StreamExecutor] Failed to parse task reconnect SSE chunk:",
|
||||
err,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") {
|
||||
notifySubscribers(stream, { type: "stream_end" });
|
||||
stream.status = "completed";
|
||||
return;
|
||||
}
|
||||
|
||||
if (retryCount < MAX_RETRIES) {
|
||||
const retryDelay = INITIAL_RETRY_DELAY * Math.pow(2, retryCount);
|
||||
console.log(
|
||||
`[StreamExecutor] Task reconnect retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`,
|
||||
);
|
||||
await new Promise((resolve) => setTimeout(resolve, retryDelay));
|
||||
return executeTaskReconnect(stream, taskId, lastMessageId, retryCount + 1);
|
||||
}
|
||||
|
||||
stream.status = "error";
|
||||
stream.error = err instanceof Error ? err : new Error("Task reconnect failed");
|
||||
notifySubscribers(stream, {
|
||||
type: "error",
|
||||
message: stream.error.message,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user