mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-02 10:55:14 -05:00
simplify and ensure agents are added to store
This commit is contained in:
@@ -20,17 +20,12 @@ from redis.exceptions import ResponseError
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .response_model import StreamError, StreamFinish, StreamToolOutputAvailable
|
||||
from .tools.models import ErrorResponse
|
||||
from .completion_handler import process_operation_failure, process_operation_success
|
||||
from .config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Stream configuration
|
||||
COMPLETION_STREAM = "chat:completions"
|
||||
CONSUMER_GROUP = "chat_consumers"
|
||||
STREAM_MAX_LENGTH = 10000
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
class OperationCompleteMessage(BaseModel):
|
||||
@@ -69,17 +64,20 @@ class ChatCompletionConsumer:
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.xgroup_create(
|
||||
COMPLETION_STREAM,
|
||||
CONSUMER_GROUP,
|
||||
config.stream_completion_name,
|
||||
config.stream_consumer_group,
|
||||
id="0",
|
||||
mkstream=True,
|
||||
)
|
||||
logger.info(
|
||||
f"Created consumer group '{CONSUMER_GROUP}' on stream '{COMPLETION_STREAM}'"
|
||||
f"Created consumer group '{config.stream_consumer_group}' "
|
||||
f"on stream '{config.stream_completion_name}'"
|
||||
)
|
||||
except ResponseError as e:
|
||||
if "BUSYGROUP" in str(e):
|
||||
logger.debug(f"Consumer group '{CONSUMER_GROUP}' already exists")
|
||||
logger.debug(
|
||||
f"Consumer group '{config.stream_consumer_group}' already exists"
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
@@ -134,9 +132,9 @@ class ChatCompletionConsumer:
|
||||
while self._running:
|
||||
# Read new messages from the stream
|
||||
messages = await redis.xreadgroup(
|
||||
groupname=CONSUMER_GROUP,
|
||||
groupname=config.stream_consumer_group,
|
||||
consumername=self._consumer_name,
|
||||
streams={COMPLETION_STREAM: ">"},
|
||||
streams={config.stream_completion_name: ">"},
|
||||
block=block_timeout,
|
||||
count=10,
|
||||
)
|
||||
@@ -161,7 +159,9 @@ class ChatCompletionConsumer:
|
||||
|
||||
# Acknowledge the message
|
||||
await redis.xack(
|
||||
COMPLETION_STREAM, CONSUMER_GROUP, entry_id
|
||||
config.stream_completion_name,
|
||||
config.stream_consumer_group,
|
||||
entry_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -237,72 +237,8 @@ class ChatCompletionConsumer:
|
||||
message: OperationCompleteMessage,
|
||||
) -> None:
|
||||
"""Handle successful operation completion."""
|
||||
# Publish result to stream registry
|
||||
result_output = message.result if message.result else {"status": "completed"}
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=task.tool_call_id,
|
||||
toolName=task.tool_name,
|
||||
output=(
|
||||
result_output
|
||||
if isinstance(result_output, str)
|
||||
else orjson.dumps(result_output).decode("utf-8")
|
||||
),
|
||||
success=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Update pending operation in database using our Prisma client
|
||||
result_str = (
|
||||
message.result
|
||||
if isinstance(message.result, str)
|
||||
else (
|
||||
orjson.dumps(message.result).decode("utf-8")
|
||||
if message.result
|
||||
else '{"status": "completed"}'
|
||||
)
|
||||
)
|
||||
try:
|
||||
prisma = await self._ensure_prisma()
|
||||
await prisma.chatmessage.update_many(
|
||||
where={
|
||||
"sessionId": task.session_id,
|
||||
"toolCallId": task.tool_call_id,
|
||||
},
|
||||
data={"content": result_str},
|
||||
)
|
||||
logger.info(
|
||||
f"[COMPLETION] Updated tool message for session {task.session_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to update tool message: {e}", exc_info=True
|
||||
)
|
||||
|
||||
# Generate LLM continuation with streaming
|
||||
try:
|
||||
await chat_service._generate_llm_continuation_with_streaming(
|
||||
session_id=task.session_id,
|
||||
user_id=task.user_id,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to generate LLM continuation: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Mark task as completed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="completed")
|
||||
try:
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
|
||||
)
|
||||
prisma = await self._ensure_prisma()
|
||||
await process_operation_success(task, message.result, prisma)
|
||||
|
||||
async def _handle_failure(
|
||||
self,
|
||||
@@ -310,47 +246,8 @@ class ChatCompletionConsumer:
|
||||
message: OperationCompleteMessage,
|
||||
) -> None:
|
||||
"""Handle failed operation completion."""
|
||||
error_msg = message.error or "Operation failed"
|
||||
|
||||
# Publish error to stream registry
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText=error_msg),
|
||||
)
|
||||
await stream_registry.publish_chunk(task.task_id, StreamFinish())
|
||||
|
||||
# Update pending operation with error using our Prisma client
|
||||
error_response = ErrorResponse(
|
||||
message=error_msg,
|
||||
error=message.error,
|
||||
)
|
||||
try:
|
||||
prisma = await self._ensure_prisma()
|
||||
await prisma.chatmessage.update_many(
|
||||
where={
|
||||
"sessionId": task.session_id,
|
||||
"toolCallId": task.tool_call_id,
|
||||
},
|
||||
data={"content": error_response.model_dump_json()},
|
||||
)
|
||||
logger.info(
|
||||
f"[COMPLETION] Updated tool message with error for session {task.session_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to update tool message: {e}", exc_info=True
|
||||
)
|
||||
|
||||
# Mark task as failed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
try:
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}"
|
||||
)
|
||||
prisma = await self._ensure_prisma()
|
||||
await process_operation_failure(task, message.error, prisma)
|
||||
|
||||
|
||||
# Module-level consumer instance
|
||||
@@ -399,8 +296,8 @@ async def publish_operation_complete(
|
||||
|
||||
redis = await get_redis_async()
|
||||
await redis.xadd(
|
||||
COMPLETION_STREAM,
|
||||
config.stream_completion_name,
|
||||
{"data": message.model_dump_json()},
|
||||
maxlen=STREAM_MAX_LENGTH,
|
||||
maxlen=config.stream_max_length,
|
||||
)
|
||||
logger.info(f"Published completion for operation {operation_id}")
|
||||
|
||||
@@ -0,0 +1,255 @@
|
||||
"""Shared completion handling for operation success and failure.
|
||||
|
||||
This module provides common logic for handling operation completion from both:
|
||||
- The Redis Streams consumer (completion_consumer.py)
|
||||
- The HTTP webhook endpoint (routes.py)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from prisma import Prisma
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .response_model import StreamError, StreamFinish, StreamToolOutputAvailable
|
||||
from .tools.models import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Tools that produce agent_json that needs to be saved to library
|
||||
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
|
||||
|
||||
|
||||
def serialize_result(result: dict | str | None) -> str:
|
||||
"""Serialize result to JSON string with sensible defaults.
|
||||
|
||||
Args:
|
||||
result: The result to serialize (dict, string, or None)
|
||||
|
||||
Returns:
|
||||
JSON string representation of the result
|
||||
"""
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
if result:
|
||||
return orjson.dumps(result).decode("utf-8")
|
||||
return '{"status": "completed"}'
|
||||
|
||||
|
||||
async def _save_agent_from_result(
|
||||
result: dict[str, Any],
|
||||
user_id: str | None,
|
||||
tool_name: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Save agent to library if result contains agent_json.
|
||||
|
||||
Args:
|
||||
result: The result dict that may contain agent_json
|
||||
user_id: The user ID to save the agent for
|
||||
tool_name: The tool name (create_agent or edit_agent)
|
||||
|
||||
Returns:
|
||||
Updated result dict with saved agent details, or original result if no agent_json
|
||||
"""
|
||||
if not user_id:
|
||||
logger.warning(
|
||||
"[COMPLETION] Cannot save agent: no user_id in task"
|
||||
)
|
||||
return result
|
||||
|
||||
agent_json = result.get("agent_json")
|
||||
if not agent_json:
|
||||
logger.warning(
|
||||
f"[COMPLETION] {tool_name} completed but no agent_json in result"
|
||||
)
|
||||
return result
|
||||
|
||||
try:
|
||||
from .tools.agent_generator import save_agent_to_library
|
||||
|
||||
is_update = tool_name == "edit_agent"
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
agent_json, user_id, is_update=is_update
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
|
||||
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
|
||||
)
|
||||
|
||||
# Return a response similar to AgentSavedResponse
|
||||
return {
|
||||
"type": "agent_saved",
|
||||
"message": f"Agent '{created_graph.name}' has been saved to your library!",
|
||||
"agent_id": created_graph.id,
|
||||
"agent_name": created_graph.name,
|
||||
"library_agent_id": library_agent.id,
|
||||
"library_agent_link": f"/library/agents/{library_agent.id}",
|
||||
"agent_page_link": f"/build?flowID={created_graph.id}",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to save agent to library: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Return error but don't fail the whole operation
|
||||
return {
|
||||
"type": "error",
|
||||
"message": f"Agent was generated but failed to save: {str(e)}",
|
||||
"error": str(e),
|
||||
"agent_json": agent_json, # Include the JSON so user can retry
|
||||
}
|
||||
|
||||
|
||||
async def process_operation_success(
|
||||
task: stream_registry.ActiveTask,
|
||||
result: dict | str | None,
|
||||
prisma_client: Prisma | None = None,
|
||||
) -> None:
|
||||
"""Handle successful operation completion.
|
||||
|
||||
Publishes the result to the stream registry, updates the database,
|
||||
generates LLM continuation, and marks the task as completed.
|
||||
|
||||
Args:
|
||||
task: The active task that completed
|
||||
result: The result data from the operation
|
||||
prisma_client: Optional Prisma client for database operations.
|
||||
If None, uses chat_service._update_pending_operation instead.
|
||||
"""
|
||||
# For agent generation tools, save the agent to library
|
||||
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
|
||||
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
|
||||
|
||||
# Serialize result for output
|
||||
result_output = result if result else {"status": "completed"}
|
||||
output_str = (
|
||||
result_output
|
||||
if isinstance(result_output, str)
|
||||
else orjson.dumps(result_output).decode("utf-8")
|
||||
)
|
||||
|
||||
# Publish result to stream registry
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=task.tool_call_id,
|
||||
toolName=task.tool_name,
|
||||
output=output_str,
|
||||
success=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Update pending operation in database
|
||||
result_str = serialize_result(result)
|
||||
try:
|
||||
if prisma_client:
|
||||
# Use provided Prisma client (for consumer with its own connection)
|
||||
await prisma_client.chatmessage.update_many(
|
||||
where={
|
||||
"sessionId": task.session_id,
|
||||
"toolCallId": task.tool_call_id,
|
||||
},
|
||||
data={"content": result_str},
|
||||
)
|
||||
logger.info(
|
||||
f"[COMPLETION] Updated tool message for session {task.session_id}"
|
||||
)
|
||||
else:
|
||||
# Use service function (for webhook endpoint)
|
||||
await chat_service._update_pending_operation(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
result=result_str,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
|
||||
|
||||
# Generate LLM continuation with streaming
|
||||
try:
|
||||
await chat_service._generate_llm_continuation_with_streaming(
|
||||
session_id=task.session_id,
|
||||
user_id=task.user_id,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to generate LLM continuation: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Mark task as completed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="completed")
|
||||
try:
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
|
||||
)
|
||||
|
||||
|
||||
async def process_operation_failure(
|
||||
task: stream_registry.ActiveTask,
|
||||
error: str | None,
|
||||
prisma_client: Prisma | None = None,
|
||||
) -> None:
|
||||
"""Handle failed operation completion.
|
||||
|
||||
Publishes the error to the stream registry, updates the database with
|
||||
the error response, and marks the task as failed.
|
||||
|
||||
Args:
|
||||
task: The active task that failed
|
||||
error: The error message from the operation
|
||||
prisma_client: Optional Prisma client for database operations.
|
||||
If None, uses chat_service._update_pending_operation instead.
|
||||
"""
|
||||
error_msg = error or "Operation failed"
|
||||
|
||||
# Publish error to stream registry
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText=error_msg),
|
||||
)
|
||||
await stream_registry.publish_chunk(task.task_id, StreamFinish())
|
||||
|
||||
# Update pending operation with error
|
||||
error_response = ErrorResponse(
|
||||
message=error_msg,
|
||||
error=error,
|
||||
)
|
||||
try:
|
||||
if prisma_client:
|
||||
# Use provided Prisma client (for consumer with its own connection)
|
||||
await prisma_client.chatmessage.update_many(
|
||||
where={
|
||||
"sessionId": task.session_id,
|
||||
"toolCallId": task.tool_call_id,
|
||||
},
|
||||
data={"content": error_response.model_dump_json()},
|
||||
)
|
||||
logger.info(
|
||||
f"[COMPLETION] Updated tool message with error for session {task.session_id}"
|
||||
)
|
||||
else:
|
||||
# Use service function (for webhook endpoint)
|
||||
await chat_service._update_pending_operation(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
result=error_response.model_dump_json(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
|
||||
|
||||
# Mark task as failed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
try:
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||
|
||||
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")
|
||||
@@ -50,9 +50,37 @@ class ChatConfig(BaseSettings):
|
||||
description="TTL in seconds for stream data in Redis (1 hour)",
|
||||
)
|
||||
stream_max_length: int = Field(
|
||||
default=1000,
|
||||
default=10000,
|
||||
description="Maximum number of messages to store per stream",
|
||||
)
|
||||
|
||||
# Redis Streams configuration for completion consumer
|
||||
stream_completion_name: str = Field(
|
||||
default="chat:completions",
|
||||
description="Redis Stream name for operation completions",
|
||||
)
|
||||
stream_consumer_group: str = Field(
|
||||
default="chat_consumers",
|
||||
description="Consumer group name for completion stream",
|
||||
)
|
||||
|
||||
# Redis key prefixes for stream registry
|
||||
task_meta_prefix: str = Field(
|
||||
default="chat:task:meta:",
|
||||
description="Prefix for task metadata hash keys",
|
||||
)
|
||||
task_stream_prefix: str = Field(
|
||||
default="chat:stream:",
|
||||
description="Prefix for task message stream keys",
|
||||
)
|
||||
task_op_prefix: str = Field(
|
||||
default="chat:task:op:",
|
||||
description="Prefix for operation ID to task ID mapping keys",
|
||||
)
|
||||
task_pubsub_prefix: str = Field(
|
||||
default="chat:task:pubsub:",
|
||||
description="Prefix for task pub/sub channel names",
|
||||
)
|
||||
internal_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
|
||||
|
||||
@@ -5,7 +5,6 @@ import uuid as uuid_module
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
|
||||
import orjson
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
@@ -15,6 +14,7 @@ from backend.util.exceptions import NotFoundError
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .completion_handler import process_operation_failure, process_operation_success
|
||||
from .config import ChatConfig
|
||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
|
||||
@@ -704,81 +704,9 @@ async def complete_operation(
|
||||
)
|
||||
|
||||
if request.success:
|
||||
# Publish result to stream registry
|
||||
from .response_model import StreamToolOutputAvailable
|
||||
|
||||
result_output = request.result if request.result else {"status": "completed"}
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=task.tool_call_id,
|
||||
toolName=task.tool_name,
|
||||
output=(
|
||||
result_output
|
||||
if isinstance(result_output, str)
|
||||
else orjson.dumps(result_output).decode("utf-8")
|
||||
),
|
||||
success=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Update pending operation in database
|
||||
from . import service as svc
|
||||
|
||||
result_str = (
|
||||
request.result
|
||||
if isinstance(request.result, str)
|
||||
else (
|
||||
orjson.dumps(request.result).decode("utf-8")
|
||||
if request.result
|
||||
else '{"status": "completed"}'
|
||||
)
|
||||
)
|
||||
await svc._update_pending_operation(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
result=result_str,
|
||||
)
|
||||
|
||||
# Generate LLM continuation with streaming
|
||||
await svc._generate_llm_continuation_with_streaming(
|
||||
session_id=task.session_id,
|
||||
user_id=task.user_id,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
|
||||
# Mark task as completed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="completed")
|
||||
await svc._mark_operation_completed(task.tool_call_id)
|
||||
await process_operation_success(task, request.result)
|
||||
else:
|
||||
# Publish error to stream registry
|
||||
from .response_model import StreamError
|
||||
|
||||
error_msg = request.error or "Operation failed"
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText=error_msg),
|
||||
)
|
||||
# Send finish event to end the stream
|
||||
await stream_registry.publish_chunk(task.task_id, StreamFinish())
|
||||
|
||||
# Update pending operation with error
|
||||
from . import service as svc
|
||||
from .tools.models import ErrorResponse
|
||||
|
||||
error_response = ErrorResponse(
|
||||
message=error_msg,
|
||||
error=request.error,
|
||||
)
|
||||
await svc._update_pending_operation(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
result=error_response.model_dump_json(),
|
||||
)
|
||||
|
||||
# Mark task as failed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
await svc._mark_operation_completed(task.tool_call_id)
|
||||
await process_operation_failure(task, request.error)
|
||||
|
||||
return {"status": "ok", "task_id": task.task_id}
|
||||
|
||||
|
||||
@@ -31,6 +31,9 @@ from .response_model import StreamBaseResponse, StreamFinish
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
|
||||
_local_tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveTask:
|
||||
@@ -47,34 +50,24 @@ class ActiveTask:
|
||||
asyncio_task: asyncio.Task | None = None
|
||||
|
||||
|
||||
# Redis key patterns
|
||||
TASK_META_PREFIX = "chat:task:meta:" # Hash for task metadata
|
||||
TASK_STREAM_PREFIX = "chat:stream:" # Redis Stream for messages
|
||||
TASK_OP_PREFIX = "chat:task:op:" # Operation ID -> task_id mapping
|
||||
TASK_PUBSUB_PREFIX = "chat:task:pubsub:" # Pub/sub channel for real-time delivery
|
||||
|
||||
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
|
||||
_local_tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
|
||||
def _get_task_meta_key(task_id: str) -> str:
|
||||
"""Get Redis key for task metadata."""
|
||||
return f"{TASK_META_PREFIX}{task_id}"
|
||||
return f"{config.task_meta_prefix}{task_id}"
|
||||
|
||||
|
||||
def _get_task_stream_key(task_id: str) -> str:
|
||||
"""Get Redis key for task message stream."""
|
||||
return f"{TASK_STREAM_PREFIX}{task_id}"
|
||||
return f"{config.task_stream_prefix}{task_id}"
|
||||
|
||||
|
||||
def _get_operation_mapping_key(operation_id: str) -> str:
|
||||
"""Get Redis key for operation_id to task_id mapping."""
|
||||
return f"{TASK_OP_PREFIX}{operation_id}"
|
||||
return f"{config.task_op_prefix}{operation_id}"
|
||||
|
||||
|
||||
def _get_task_pubsub_channel(task_id: str) -> str:
|
||||
"""Get Redis pub/sub channel for task real-time delivery."""
|
||||
return f"{TASK_PUBSUB_PREFIX}{task_id}"
|
||||
return f"{config.task_pubsub_prefix}{task_id}"
|
||||
|
||||
|
||||
async def create_task(
|
||||
@@ -466,7 +459,9 @@ async def get_active_task_for_session(
|
||||
tasks_checked = 0
|
||||
|
||||
while True:
|
||||
cursor, keys = await redis.scan(cursor, match=f"{TASK_META_PREFIX}*", count=100)
|
||||
cursor, keys = await redis.scan(
|
||||
cursor, match=f"{config.task_meta_prefix}*", count=100
|
||||
)
|
||||
|
||||
for key in keys:
|
||||
tasks_checked += 1
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
/**
|
||||
* Constants for the chat system.
|
||||
*
|
||||
* Centralizes magic strings and values used across chat components.
|
||||
*/
|
||||
|
||||
// LocalStorage keys
|
||||
export const STORAGE_KEY_ACTIVE_TASKS = "chat_active_tasks";
|
||||
|
||||
// Redis Stream IDs
|
||||
export const INITIAL_MESSAGE_ID = "0";
|
||||
export const INITIAL_STREAM_ID = "0-0";
|
||||
|
||||
// TTL values (in milliseconds)
|
||||
export const COMPLETED_STREAM_TTL_MS = 5 * 60 * 1000; // 5 minutes
|
||||
export const ACTIVE_TASK_TTL_MS = 60 * 60 * 1000; // 1 hour
|
||||
@@ -1,6 +1,12 @@
|
||||
"use client";
|
||||
|
||||
import { create } from "zustand";
|
||||
import {
|
||||
ACTIVE_TASK_TTL_MS,
|
||||
COMPLETED_STREAM_TTL_MS,
|
||||
INITIAL_STREAM_ID,
|
||||
STORAGE_KEY_ACTIVE_TASKS,
|
||||
} from "./chat-constants";
|
||||
import type {
|
||||
ActiveStream,
|
||||
StreamChunk,
|
||||
@@ -10,10 +16,6 @@ import type {
|
||||
} from "./chat-types";
|
||||
import { executeStream, executeTaskReconnect } from "./stream-executor";
|
||||
|
||||
const COMPLETED_STREAM_TTL = 5 * 60 * 1000; // 5 minutes
|
||||
const ACTIVE_TASKS_STORAGE_KEY = "chat_active_tasks";
|
||||
const TASK_TTL = 60 * 60 * 1000; // 1 hour - tasks expire after this
|
||||
|
||||
/**
|
||||
* Tracks active task info for SSE reconnection.
|
||||
* When a long-running operation starts, we store this so clients can reconnect
|
||||
@@ -32,14 +34,14 @@ export interface ActiveTaskInfo {
|
||||
function loadPersistedTasks(): Map<string, ActiveTaskInfo> {
|
||||
if (typeof window === "undefined") return new Map();
|
||||
try {
|
||||
const stored = localStorage.getItem(ACTIVE_TASKS_STORAGE_KEY);
|
||||
const stored = localStorage.getItem(STORAGE_KEY_ACTIVE_TASKS);
|
||||
if (!stored) return new Map();
|
||||
const parsed = JSON.parse(stored) as Record<string, ActiveTaskInfo>;
|
||||
const now = Date.now();
|
||||
const tasks = new Map<string, ActiveTaskInfo>();
|
||||
// Filter out expired tasks
|
||||
for (const [sessionId, task] of Object.entries(parsed)) {
|
||||
if (now - task.startedAt < TASK_TTL) {
|
||||
if (now - task.startedAt < ACTIVE_TASK_TTL_MS) {
|
||||
tasks.set(sessionId, task);
|
||||
}
|
||||
}
|
||||
@@ -57,7 +59,7 @@ function persistTasks(tasks: Map<string, ActiveTaskInfo>): void {
|
||||
for (const [sessionId, task] of tasks) {
|
||||
obj[sessionId] = task;
|
||||
}
|
||||
localStorage.setItem(ACTIVE_TASKS_STORAGE_KEY, JSON.stringify(obj));
|
||||
localStorage.setItem(STORAGE_KEY_ACTIVE_TASKS, JSON.stringify(obj));
|
||||
} catch {
|
||||
// Ignore storage errors
|
||||
}
|
||||
@@ -135,13 +137,73 @@ function cleanupExpiredStreams(
|
||||
const now = Date.now();
|
||||
const cleaned = new Map(completedStreams);
|
||||
for (const [sessionId, result] of cleaned) {
|
||||
if (now - result.completedAt > COMPLETED_STREAM_TTL) {
|
||||
if (now - result.completedAt > COMPLETED_STREAM_TTL_MS) {
|
||||
cleaned.delete(sessionId);
|
||||
}
|
||||
}
|
||||
return cleaned;
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean up an existing stream for a session and move it to completed streams.
|
||||
* Returns updated maps for both active and completed streams.
|
||||
*/
|
||||
function cleanupExistingStream(
|
||||
sessionId: string,
|
||||
activeStreams: Map<string, ActiveStream>,
|
||||
completedStreams: Map<string, StreamResult>,
|
||||
callbacks: Set<StreamCompleteCallback>,
|
||||
): {
|
||||
activeStreams: Map<string, ActiveStream>;
|
||||
completedStreams: Map<string, StreamResult>;
|
||||
} {
|
||||
const newActiveStreams = new Map(activeStreams);
|
||||
let newCompletedStreams = new Map(completedStreams);
|
||||
|
||||
const existingStream = newActiveStreams.get(sessionId);
|
||||
if (existingStream) {
|
||||
existingStream.abortController.abort();
|
||||
const normalizedStatus =
|
||||
existingStream.status === "streaming" ? "completed" : existingStream.status;
|
||||
const result: StreamResult = {
|
||||
sessionId,
|
||||
status: normalizedStatus,
|
||||
chunks: existingStream.chunks,
|
||||
completedAt: Date.now(),
|
||||
error: existingStream.error,
|
||||
};
|
||||
newCompletedStreams.set(sessionId, result);
|
||||
newActiveStreams.delete(sessionId);
|
||||
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
|
||||
if (normalizedStatus === "completed" || normalizedStatus === "error") {
|
||||
notifyStreamComplete(callbacks, sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
return { activeStreams: newActiveStreams, completedStreams: newCompletedStreams };
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new active stream with initial state.
|
||||
*/
|
||||
function createActiveStream(
|
||||
sessionId: string,
|
||||
onChunk?: (chunk: StreamChunk) => void,
|
||||
): ActiveStream {
|
||||
const abortController = new AbortController();
|
||||
const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
|
||||
if (onChunk) initialCallbacks.add(onChunk);
|
||||
|
||||
return {
|
||||
sessionId,
|
||||
abortController,
|
||||
status: "streaming",
|
||||
startedAt: Date.now(),
|
||||
chunks: [],
|
||||
onChunkCallbacks: initialCallbacks,
|
||||
};
|
||||
}
|
||||
|
||||
export const useChatStore = create<ChatStore>((set, get) => ({
|
||||
activeStreams: new Map(),
|
||||
completedStreams: new Map(),
|
||||
@@ -157,45 +219,19 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
||||
onChunk,
|
||||
) {
|
||||
const state = get();
|
||||
const newActiveStreams = new Map(state.activeStreams);
|
||||
let newCompletedStreams = new Map(state.completedStreams);
|
||||
const callbacks = state.streamCompleteCallbacks;
|
||||
|
||||
const existingStream = newActiveStreams.get(sessionId);
|
||||
if (existingStream) {
|
||||
existingStream.abortController.abort();
|
||||
const normalizedStatus =
|
||||
existingStream.status === "streaming"
|
||||
? "completed"
|
||||
: existingStream.status;
|
||||
const result: StreamResult = {
|
||||
// Clean up any existing stream for this session
|
||||
const { activeStreams: newActiveStreams, completedStreams: newCompletedStreams } =
|
||||
cleanupExistingStream(
|
||||
sessionId,
|
||||
status: normalizedStatus,
|
||||
chunks: existingStream.chunks,
|
||||
completedAt: Date.now(),
|
||||
error: existingStream.error,
|
||||
};
|
||||
newCompletedStreams.set(sessionId, result);
|
||||
newActiveStreams.delete(sessionId);
|
||||
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
|
||||
if (normalizedStatus === "completed" || normalizedStatus === "error") {
|
||||
notifyStreamComplete(callbacks, sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
const abortController = new AbortController();
|
||||
const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
|
||||
if (onChunk) initialCallbacks.add(onChunk);
|
||||
|
||||
const stream: ActiveStream = {
|
||||
sessionId,
|
||||
abortController,
|
||||
status: "streaming",
|
||||
startedAt: Date.now(),
|
||||
chunks: [],
|
||||
onChunkCallbacks: initialCallbacks,
|
||||
};
|
||||
state.activeStreams,
|
||||
state.completedStreams,
|
||||
callbacks,
|
||||
);
|
||||
|
||||
// Create new stream
|
||||
const stream = createActiveStream(sessionId, onChunk);
|
||||
newActiveStreams.set(sessionId, stream);
|
||||
set({
|
||||
activeStreams: newActiveStreams,
|
||||
@@ -388,7 +424,7 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
||||
reconnectToTask: async function reconnectToTask(
|
||||
sessionId,
|
||||
taskId,
|
||||
lastMessageId = "0-0", // Redis Stream ID format
|
||||
lastMessageId = INITIAL_STREAM_ID,
|
||||
onChunk,
|
||||
) {
|
||||
console.info("[SSE-RECONNECT] reconnectToTask called:", {
|
||||
@@ -398,43 +434,19 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
||||
});
|
||||
|
||||
const state = get();
|
||||
const newActiveStreams = new Map(state.activeStreams);
|
||||
let newCompletedStreams = new Map(state.completedStreams);
|
||||
const callbacks = state.streamCompleteCallbacks;
|
||||
|
||||
// Clean up any existing stream for this session
|
||||
const existingStream = newActiveStreams.get(sessionId);
|
||||
if (existingStream) {
|
||||
existingStream.abortController.abort();
|
||||
const normalizedStatus =
|
||||
existingStream.status === "streaming"
|
||||
? "completed"
|
||||
: existingStream.status;
|
||||
const result: StreamResult = {
|
||||
const { activeStreams: newActiveStreams, completedStreams: newCompletedStreams } =
|
||||
cleanupExistingStream(
|
||||
sessionId,
|
||||
status: normalizedStatus,
|
||||
chunks: existingStream.chunks,
|
||||
completedAt: Date.now(),
|
||||
error: existingStream.error,
|
||||
};
|
||||
newCompletedStreams.set(sessionId, result);
|
||||
newActiveStreams.delete(sessionId);
|
||||
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
|
||||
}
|
||||
|
||||
const abortController = new AbortController();
|
||||
const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
|
||||
if (onChunk) initialCallbacks.add(onChunk);
|
||||
|
||||
const stream: ActiveStream = {
|
||||
sessionId,
|
||||
abortController,
|
||||
status: "streaming",
|
||||
startedAt: Date.now(),
|
||||
chunks: [],
|
||||
onChunkCallbacks: initialCallbacks,
|
||||
};
|
||||
state.activeStreams,
|
||||
state.completedStreams,
|
||||
callbacks,
|
||||
);
|
||||
|
||||
// Create new stream for reconnection
|
||||
const stream = createActiveStream(sessionId, onChunk);
|
||||
newActiveStreams.set(sessionId, stream);
|
||||
set({
|
||||
activeStreams: newActiveStreams,
|
||||
|
||||
@@ -94,3 +94,67 @@ export interface StreamResult {
|
||||
}
|
||||
|
||||
export type StreamCompleteCallback = (sessionId: string) => void;
|
||||
|
||||
// Type guards for message types
|
||||
|
||||
/**
|
||||
* Check if a message has a toolId property.
|
||||
*/
|
||||
export function hasToolId<T extends { type: string }>(
|
||||
msg: T,
|
||||
): msg is T & { toolId: string } {
|
||||
return "toolId" in msg && typeof (msg as Record<string, unknown>).toolId === "string";
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a message has an operationId property.
|
||||
*/
|
||||
export function hasOperationId<T extends { type: string }>(
|
||||
msg: T,
|
||||
): msg is T & { operationId: string } {
|
||||
return (
|
||||
"operationId" in msg &&
|
||||
typeof (msg as Record<string, unknown>).operationId === "string"
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a message has a toolCallId property.
|
||||
*/
|
||||
export function hasToolCallId<T extends { type: string }>(
|
||||
msg: T,
|
||||
): msg is T & { toolCallId: string } {
|
||||
return (
|
||||
"toolCallId" in msg &&
|
||||
typeof (msg as Record<string, unknown>).toolCallId === "string"
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a message is an operation message type.
|
||||
*/
|
||||
export function isOperationMessage<T extends { type: string }>(
|
||||
msg: T,
|
||||
): msg is T & {
|
||||
type: "operation_started" | "operation_pending" | "operation_in_progress";
|
||||
} {
|
||||
return (
|
||||
msg.type === "operation_started" ||
|
||||
msg.type === "operation_pending" ||
|
||||
msg.type === "operation_in_progress"
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the tool ID from a message if available.
|
||||
* Checks toolId, operationId, and toolCallId properties.
|
||||
*/
|
||||
export function getToolIdFromMessage<T extends { type: string }>(
|
||||
msg: T,
|
||||
): string | undefined {
|
||||
const record = msg as Record<string, unknown>;
|
||||
if (typeof record.toolId === "string") return record.toolId;
|
||||
if (typeof record.operationId === "string") return record.operationId;
|
||||
if (typeof record.toolCallId === "string") return record.toolCallId;
|
||||
return undefined;
|
||||
}
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
||||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
import { INITIAL_STREAM_ID } from "../../chat-constants";
|
||||
import { useChatStore } from "../../chat-store";
|
||||
import { toast } from "sonner";
|
||||
import { useChatStream } from "../../useChatStream";
|
||||
import { usePageContext } from "../../usePageContext";
|
||||
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
|
||||
import {
|
||||
getToolIdFromMessage,
|
||||
hasToolId,
|
||||
isOperationMessage,
|
||||
} from "../../chat-types";
|
||||
import { createStreamEventDispatcher } from "./createStreamEventDispatcher";
|
||||
import {
|
||||
createUserMessage,
|
||||
@@ -14,6 +20,46 @@ import {
|
||||
processInitialMessages,
|
||||
} from "./helpers";
|
||||
|
||||
/**
|
||||
* Dependencies for creating a stream event dispatcher.
|
||||
* Extracted to allow helper function creation.
|
||||
*/
|
||||
interface DispatcherDeps {
|
||||
setHasTextChunks: React.Dispatch<React.SetStateAction<boolean>>;
|
||||
setStreamingChunks: React.Dispatch<React.SetStateAction<string[]>>;
|
||||
streamingChunksRef: React.MutableRefObject<string[]>;
|
||||
hasResponseRef: React.MutableRefObject<boolean>;
|
||||
setMessages: React.Dispatch<React.SetStateAction<ChatMessageData[]>>;
|
||||
setIsRegionBlockedModalOpen: React.Dispatch<React.SetStateAction<boolean>>;
|
||||
sessionId: string;
|
||||
setIsStreamingInitiated: React.Dispatch<React.SetStateAction<boolean>>;
|
||||
onOperationStarted?: () => void;
|
||||
onActiveTaskStarted: (taskInfo: {
|
||||
taskId: string;
|
||||
operationId: string;
|
||||
toolName: string;
|
||||
toolCallId: string;
|
||||
}) => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a stream event dispatcher with the given dependencies.
|
||||
*/
|
||||
function createDispatcher(deps: DispatcherDeps) {
|
||||
return createStreamEventDispatcher({
|
||||
setHasTextChunks: deps.setHasTextChunks,
|
||||
setStreamingChunks: deps.setStreamingChunks,
|
||||
streamingChunksRef: deps.streamingChunksRef,
|
||||
hasResponseRef: deps.hasResponseRef,
|
||||
setMessages: deps.setMessages,
|
||||
setIsRegionBlockedModalOpen: deps.setIsRegionBlockedModalOpen,
|
||||
sessionId: deps.sessionId,
|
||||
setIsStreamingInitiated: deps.setIsStreamingInitiated,
|
||||
onOperationStarted: deps.onOperationStarted,
|
||||
onActiveTaskStarted: deps.onActiveTaskStarted,
|
||||
});
|
||||
}
|
||||
|
||||
// Helper to generate deduplication key for a message
|
||||
function getMessageKey(msg: ChatMessageData): string {
|
||||
if (msg.type === "message") {
|
||||
@@ -24,13 +70,11 @@ function getMessageKey(msg: ChatMessageData): string {
|
||||
} else if (msg.type === "tool_call") {
|
||||
return `toolcall:${msg.toolId}`;
|
||||
} else if (msg.type === "tool_response") {
|
||||
return `toolresponse:${(msg as any).toolId}`;
|
||||
} else if (
|
||||
msg.type === "operation_started" ||
|
||||
msg.type === "operation_pending" ||
|
||||
msg.type === "operation_in_progress"
|
||||
) {
|
||||
return `op:${(msg as any).toolId || (msg as any).operationId || (msg as any).toolCallId || ""}:${msg.toolName}`;
|
||||
const toolId = hasToolId(msg) ? msg.toolId : "";
|
||||
return `toolresponse:${toolId}`;
|
||||
} else if (isOperationMessage(msg)) {
|
||||
const toolId = getToolIdFromMessage(msg) || "";
|
||||
return `op:${toolId}:${msg.toolName}`;
|
||||
} else {
|
||||
return `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`;
|
||||
}
|
||||
@@ -90,7 +134,7 @@ export function useChatContainer({
|
||||
taskId: taskInfo.taskId,
|
||||
operationId: taskInfo.operationId,
|
||||
toolName: taskInfo.toolName,
|
||||
lastMessageId: "0-0", // Redis Stream ID format for full replay
|
||||
lastMessageId: INITIAL_STREAM_ID,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -168,7 +212,7 @@ export function useChatContainer({
|
||||
},
|
||||
);
|
||||
|
||||
const dispatcher = createStreamEventDispatcher({
|
||||
const dispatcher = createDispatcher({
|
||||
setHasTextChunks,
|
||||
setStreamingChunks,
|
||||
streamingChunksRef,
|
||||
@@ -221,7 +265,7 @@ export function useChatContainer({
|
||||
},
|
||||
);
|
||||
|
||||
const dispatcher = createStreamEventDispatcher({
|
||||
const dispatcher = createDispatcher({
|
||||
setHasTextChunks,
|
||||
setStreamingChunks,
|
||||
streamingChunksRef,
|
||||
@@ -259,7 +303,7 @@ export function useChatContainer({
|
||||
return;
|
||||
}
|
||||
|
||||
const dispatcher = createStreamEventDispatcher({
|
||||
const dispatcher = createDispatcher({
|
||||
setHasTextChunks,
|
||||
setStreamingChunks,
|
||||
streamingChunksRef,
|
||||
@@ -300,7 +344,7 @@ export function useChatContainer({
|
||||
msg.type === "agent_carousel" ||
|
||||
msg.type === "execution_started"
|
||||
) {
|
||||
const toolId = (msg as any).toolId;
|
||||
const toolId = hasToolId(msg) ? msg.toolId : undefined;
|
||||
if (toolId) {
|
||||
ids.add(toolId);
|
||||
}
|
||||
@@ -317,12 +361,8 @@ export function useChatContainer({
|
||||
|
||||
setMessages((prev) => {
|
||||
const filtered = prev.filter((msg) => {
|
||||
if (
|
||||
msg.type === "operation_started" ||
|
||||
msg.type === "operation_pending" ||
|
||||
msg.type === "operation_in_progress"
|
||||
) {
|
||||
const toolId = (msg as any).toolId || (msg as any).toolCallId;
|
||||
if (isOperationMessage(msg)) {
|
||||
const toolId = getToolIdFromMessage(msg);
|
||||
if (toolId && completedToolIds.has(toolId)) {
|
||||
return false; // Remove - operation completed
|
||||
}
|
||||
@@ -350,12 +390,8 @@ export function useChatContainer({
|
||||
// Filter local messages: remove duplicates and completed operation messages
|
||||
const newLocalMessages = messages.filter((msg) => {
|
||||
// Remove operation messages for completed tools
|
||||
if (
|
||||
msg.type === "operation_started" ||
|
||||
msg.type === "operation_pending" ||
|
||||
msg.type === "operation_in_progress"
|
||||
) {
|
||||
const toolId = (msg as any).toolId || (msg as any).toolCallId;
|
||||
if (isOperationMessage(msg)) {
|
||||
const toolId = getToolIdFromMessage(msg);
|
||||
if (toolId && completedToolIds.has(toolId)) {
|
||||
return false;
|
||||
}
|
||||
@@ -391,7 +427,7 @@ export function useChatContainer({
|
||||
setIsStreamingInitiated(true);
|
||||
hasResponseRef.current = false;
|
||||
|
||||
const dispatcher = createStreamEventDispatcher({
|
||||
const dispatcher = createDispatcher({
|
||||
setHasTextChunks,
|
||||
setStreamingChunks,
|
||||
streamingChunksRef,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { INITIAL_MESSAGE_ID } from "./chat-constants";
|
||||
import type {
|
||||
ActiveStream,
|
||||
StreamChunk,
|
||||
@@ -27,178 +28,118 @@ function notifySubscribers(
|
||||
}
|
||||
}
|
||||
|
||||
export async function executeStream(
|
||||
stream: ActiveStream,
|
||||
message: string,
|
||||
isUserMessage: boolean,
|
||||
context?: { url: string; content: string },
|
||||
retryCount: number = 0,
|
||||
): Promise<void> {
|
||||
const { sessionId, abortController } = stream;
|
||||
|
||||
try {
|
||||
const url = `/api/chat/sessions/${sessionId}/stream`;
|
||||
const body = JSON.stringify({
|
||||
message,
|
||||
is_user_message: isUserMessage,
|
||||
context: context || null,
|
||||
});
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Accept: "text/event-stream",
|
||||
},
|
||||
body,
|
||||
signal: abortController.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(errorText || `HTTP ${response.status}`);
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
throw new Error("Response body is null");
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
|
||||
if (done) {
|
||||
notifySubscribers(stream, { type: "stream_end" });
|
||||
stream.status = "completed";
|
||||
return;
|
||||
}
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split("\n");
|
||||
buffer = lines.pop() || "";
|
||||
|
||||
for (const line of lines) {
|
||||
const data = parseSSELine(line);
|
||||
if (data !== null) {
|
||||
if (data === "[DONE]") {
|
||||
notifySubscribers(stream, { type: "stream_end" });
|
||||
stream.status = "completed";
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const rawChunk = JSON.parse(data) as
|
||||
| StreamChunk
|
||||
| VercelStreamChunk;
|
||||
const chunk = normalizeStreamChunk(rawChunk);
|
||||
if (!chunk) continue;
|
||||
|
||||
notifySubscribers(stream, chunk);
|
||||
|
||||
if (chunk.type === "stream_end") {
|
||||
stream.status = "completed";
|
||||
return;
|
||||
}
|
||||
|
||||
if (chunk.type === "error") {
|
||||
stream.status = "error";
|
||||
stream.error = new Error(
|
||||
chunk.message || chunk.content || "Stream error",
|
||||
);
|
||||
return;
|
||||
}
|
||||
} catch (err) {
|
||||
console.warn("[StreamExecutor] Failed to parse SSE chunk:", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") {
|
||||
notifySubscribers(stream, { type: "stream_end" });
|
||||
stream.status = "completed";
|
||||
return;
|
||||
}
|
||||
|
||||
if (retryCount < MAX_RETRIES) {
|
||||
const retryDelay = INITIAL_RETRY_DELAY * Math.pow(2, retryCount);
|
||||
console.log(
|
||||
`[StreamExecutor] Retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`,
|
||||
);
|
||||
await new Promise((resolve) => setTimeout(resolve, retryDelay));
|
||||
return executeStream(
|
||||
stream,
|
||||
message,
|
||||
isUserMessage,
|
||||
context,
|
||||
retryCount + 1,
|
||||
);
|
||||
}
|
||||
|
||||
stream.status = "error";
|
||||
stream.error = err instanceof Error ? err : new Error("Stream failed");
|
||||
notifySubscribers(stream, {
|
||||
type: "error",
|
||||
message: stream.error.message,
|
||||
});
|
||||
}
|
||||
/**
|
||||
* Options for stream execution.
|
||||
*/
|
||||
interface StreamExecutionOptions {
|
||||
/** The active stream state object */
|
||||
stream: ActiveStream;
|
||||
/** Execution mode: 'new' for new stream, 'reconnect' for task reconnection */
|
||||
mode: "new" | "reconnect";
|
||||
/** Message content (required for 'new' mode) */
|
||||
message?: string;
|
||||
/** Whether this is a user message (for 'new' mode) */
|
||||
isUserMessage?: boolean;
|
||||
/** Optional context for the message (for 'new' mode) */
|
||||
context?: { url: string; content: string };
|
||||
/** Task ID (required for 'reconnect' mode) */
|
||||
taskId?: string;
|
||||
/** Last message ID for replay (for 'reconnect' mode) */
|
||||
lastMessageId?: string;
|
||||
/** Current retry count (internal use) */
|
||||
retryCount?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reconnect to an existing task stream.
|
||||
* Unified stream execution function that handles both new streams and task reconnection.
|
||||
*
|
||||
* This is used when a client wants to resume receiving updates from a
|
||||
* long-running background task. Messages are replayed from the last_message_id
|
||||
* position, allowing clients to catch up on missed events.
|
||||
* For new streams:
|
||||
* - Posts a message to create a new chat stream
|
||||
* - Reads SSE chunks and notifies subscribers
|
||||
*
|
||||
* @param stream - The active stream state
|
||||
* @param taskId - The task ID to reconnect to
|
||||
* @param lastMessageId - The last message ID received (for replay)
|
||||
* @param retryCount - Current retry count
|
||||
* For reconnection:
|
||||
* - Connects to an existing task stream
|
||||
* - Replays messages from lastMessageId position
|
||||
* - Allows resumption of long-running operations
|
||||
*/
|
||||
export async function executeTaskReconnect(
|
||||
stream: ActiveStream,
|
||||
taskId: string,
|
||||
lastMessageId: string = "0",
|
||||
retryCount: number = 0,
|
||||
async function executeStreamInternal(
|
||||
options: StreamExecutionOptions,
|
||||
): Promise<void> {
|
||||
const { abortController } = stream;
|
||||
|
||||
console.info("[SSE-RECONNECT] executeTaskReconnect starting:", {
|
||||
const {
|
||||
stream,
|
||||
mode,
|
||||
message,
|
||||
isUserMessage,
|
||||
context,
|
||||
taskId,
|
||||
lastMessageId,
|
||||
retryCount,
|
||||
});
|
||||
lastMessageId = INITIAL_MESSAGE_ID,
|
||||
retryCount = 0,
|
||||
} = options;
|
||||
|
||||
const { sessionId, abortController } = stream;
|
||||
const isReconnect = mode === "reconnect";
|
||||
const logPrefix = isReconnect ? "[SSE-RECONNECT]" : "[StreamExecutor]";
|
||||
|
||||
if (isReconnect) {
|
||||
console.info(`${logPrefix} executeStream starting:`, {
|
||||
taskId,
|
||||
lastMessageId,
|
||||
retryCount,
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
const url = `/api/chat/tasks/${taskId}/stream?last_message_id=${encodeURIComponent(lastMessageId)}`;
|
||||
console.info("[SSE-RECONNECT] Fetching task stream:", { url });
|
||||
// Build URL and request options based on mode
|
||||
let url: string;
|
||||
let fetchOptions: RequestInit;
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
Accept: "text/event-stream",
|
||||
},
|
||||
signal: abortController.signal,
|
||||
});
|
||||
if (isReconnect) {
|
||||
url = `/api/chat/tasks/${taskId}/stream?last_message_id=${encodeURIComponent(lastMessageId)}`;
|
||||
fetchOptions = {
|
||||
method: "GET",
|
||||
headers: {
|
||||
Accept: "text/event-stream",
|
||||
},
|
||||
signal: abortController.signal,
|
||||
};
|
||||
console.info(`${logPrefix} Fetching task stream:`, { url });
|
||||
} else {
|
||||
url = `/api/chat/sessions/${sessionId}/stream`;
|
||||
fetchOptions = {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Accept: "text/event-stream",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
message,
|
||||
is_user_message: isUserMessage,
|
||||
context: context || null,
|
||||
}),
|
||||
signal: abortController.signal,
|
||||
};
|
||||
}
|
||||
|
||||
console.info("[SSE-RECONNECT] Task stream response:", {
|
||||
status: response.status,
|
||||
ok: response.ok,
|
||||
});
|
||||
const response = await fetch(url, fetchOptions);
|
||||
|
||||
if (isReconnect) {
|
||||
console.info(`${logPrefix} Task stream response:`, {
|
||||
status: response.status,
|
||||
ok: response.ok,
|
||||
});
|
||||
}
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
console.error("[SSE-RECONNECT] Task stream error response:", {
|
||||
status: response.status,
|
||||
errorText,
|
||||
});
|
||||
// Don't retry on 404 (task not found) or 403 (access denied) - these are permanent errors
|
||||
if (isReconnect) {
|
||||
console.error(`${logPrefix} Task stream error response:`, {
|
||||
status: response.status,
|
||||
errorText,
|
||||
});
|
||||
}
|
||||
// For reconnect: don't retry on 404/403 (permanent errors)
|
||||
const isPermanentError =
|
||||
response.status === 404 || response.status === 403;
|
||||
isReconnect && (response.status === 404 || response.status === 403);
|
||||
const error = new Error(errorText || `HTTP ${response.status}`);
|
||||
(error as Error & { status?: number }).status = response.status;
|
||||
(error as Error & { isPermanent?: boolean }).isPermanent =
|
||||
@@ -210,7 +151,9 @@ export async function executeTaskReconnect(
|
||||
throw new Error("Response body is null");
|
||||
}
|
||||
|
||||
console.info("[SSE-RECONNECT] Task stream connected, reading chunks...");
|
||||
if (isReconnect) {
|
||||
console.info(`${logPrefix} Task stream connected, reading chunks...`);
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
@@ -220,7 +163,11 @@ export async function executeTaskReconnect(
|
||||
const { done, value } = await reader.read();
|
||||
|
||||
if (done) {
|
||||
console.info("[SSE-RECONNECT] Task stream reader done (connection closed)");
|
||||
if (isReconnect) {
|
||||
console.info(
|
||||
`${logPrefix} Task stream reader done (connection closed)`,
|
||||
);
|
||||
}
|
||||
notifySubscribers(stream, { type: "stream_end" });
|
||||
stream.status = "completed";
|
||||
return;
|
||||
@@ -234,7 +181,9 @@ export async function executeTaskReconnect(
|
||||
const data = parseSSELine(line);
|
||||
if (data !== null) {
|
||||
if (data === "[DONE]") {
|
||||
console.info("[SSE-RECONNECT] Task stream received [DONE] signal");
|
||||
if (isReconnect) {
|
||||
console.info(`${logPrefix} Task stream received [DONE] signal`);
|
||||
}
|
||||
notifySubscribers(stream, { type: "stream_end" });
|
||||
stream.status = "completed";
|
||||
return;
|
||||
@@ -247,9 +196,9 @@ export async function executeTaskReconnect(
|
||||
const chunk = normalizeStreamChunk(rawChunk);
|
||||
if (!chunk) continue;
|
||||
|
||||
// Log first few chunks for debugging
|
||||
if (stream.chunks.length < 3) {
|
||||
console.info("[SSE-RECONNECT] Task stream chunk received:", {
|
||||
// Log first few chunks for debugging (reconnect mode only)
|
||||
if (isReconnect && stream.chunks.length < 3) {
|
||||
console.info(`${logPrefix} Task stream chunk received:`, {
|
||||
type: chunk.type,
|
||||
chunkIndex: stream.chunks.length,
|
||||
});
|
||||
@@ -258,13 +207,19 @@ export async function executeTaskReconnect(
|
||||
notifySubscribers(stream, chunk);
|
||||
|
||||
if (chunk.type === "stream_end") {
|
||||
console.info("[SSE-RECONNECT] Task stream completed via stream_end chunk");
|
||||
if (isReconnect) {
|
||||
console.info(
|
||||
`${logPrefix} Task stream completed via stream_end chunk`,
|
||||
);
|
||||
}
|
||||
stream.status = "completed";
|
||||
return;
|
||||
}
|
||||
|
||||
if (chunk.type === "error") {
|
||||
console.error("[SSE-RECONNECT] Task stream error chunk:", chunk);
|
||||
if (isReconnect) {
|
||||
console.error(`${logPrefix} Task stream error chunk:`, chunk);
|
||||
}
|
||||
stream.status = "error";
|
||||
stream.error = new Error(
|
||||
chunk.message || chunk.content || "Stream error",
|
||||
@@ -272,10 +227,7 @@ export async function executeTaskReconnect(
|
||||
return;
|
||||
}
|
||||
} catch (err) {
|
||||
console.warn(
|
||||
"[StreamExecutor] Failed to parse task reconnect SSE chunk:",
|
||||
err,
|
||||
);
|
||||
console.warn(`${logPrefix} Failed to parse SSE chunk:`, err);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -295,30 +247,76 @@ export async function executeTaskReconnect(
|
||||
if (!isPermanentError && retryCount < MAX_RETRIES) {
|
||||
const retryDelay = INITIAL_RETRY_DELAY * Math.pow(2, retryCount);
|
||||
console.log(
|
||||
`[StreamExecutor] Task reconnect retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`,
|
||||
`${logPrefix} Retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`,
|
||||
);
|
||||
await new Promise((resolve) => setTimeout(resolve, retryDelay));
|
||||
return executeTaskReconnect(
|
||||
stream,
|
||||
taskId,
|
||||
lastMessageId,
|
||||
retryCount + 1,
|
||||
);
|
||||
return executeStreamInternal({
|
||||
...options,
|
||||
retryCount: retryCount + 1,
|
||||
});
|
||||
}
|
||||
|
||||
// Log permanent errors differently for debugging
|
||||
if (isPermanentError) {
|
||||
console.log(
|
||||
`[StreamExecutor] Task reconnect failed permanently (task not found or access denied): ${(err as Error).message}`,
|
||||
`${logPrefix} Stream failed permanently (task not found or access denied): ${(err as Error).message}`,
|
||||
);
|
||||
}
|
||||
|
||||
stream.status = "error";
|
||||
stream.error =
|
||||
err instanceof Error ? err : new Error("Task reconnect failed");
|
||||
stream.error = err instanceof Error ? err : new Error("Stream failed");
|
||||
notifySubscribers(stream, {
|
||||
type: "error",
|
||||
message: stream.error.message,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a new chat stream.
|
||||
*
|
||||
* Posts a message to create a new stream and reads SSE responses.
|
||||
*/
|
||||
export async function executeStream(
|
||||
stream: ActiveStream,
|
||||
message: string,
|
||||
isUserMessage: boolean,
|
||||
context?: { url: string; content: string },
|
||||
retryCount: number = 0,
|
||||
): Promise<void> {
|
||||
return executeStreamInternal({
|
||||
stream,
|
||||
mode: "new",
|
||||
message,
|
||||
isUserMessage,
|
||||
context,
|
||||
retryCount,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Reconnect to an existing task stream.
|
||||
*
|
||||
* This is used when a client wants to resume receiving updates from a
|
||||
* long-running background task. Messages are replayed from the last_message_id
|
||||
* position, allowing clients to catch up on missed events.
|
||||
*
|
||||
* @param stream - The active stream state
|
||||
* @param taskId - The task ID to reconnect to
|
||||
* @param lastMessageId - The last message ID received (for replay)
|
||||
* @param retryCount - Current retry count
|
||||
*/
|
||||
export async function executeTaskReconnect(
|
||||
stream: ActiveStream,
|
||||
taskId: string,
|
||||
lastMessageId: string = INITIAL_MESSAGE_ID,
|
||||
retryCount: number = 0,
|
||||
): Promise<void> {
|
||||
return executeStreamInternal({
|
||||
stream,
|
||||
mode: "reconnect",
|
||||
taskId,
|
||||
lastMessageId,
|
||||
retryCount,
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user