Simplify completion handling and ensure generated agents are saved to the library

Swifty
2026-02-02 16:15:38 +01:00
parent ef3fab57fd
commit 070d56166c
10 changed files with 723 additions and 494 deletions

View File

@@ -20,17 +20,12 @@ from redis.exceptions import ResponseError
from backend.data.redis_client import get_redis_async
from . import service as chat_service
from . import stream_registry
from .response_model import StreamError, StreamFinish, StreamToolOutputAvailable
from .tools.models import ErrorResponse
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
logger = logging.getLogger(__name__)
# Stream configuration
COMPLETION_STREAM = "chat:completions"
CONSUMER_GROUP = "chat_consumers"
STREAM_MAX_LENGTH = 10000
config = ChatConfig()
class OperationCompleteMessage(BaseModel):
@@ -69,17 +64,20 @@ class ChatCompletionConsumer:
try:
redis = await get_redis_async()
await redis.xgroup_create(
COMPLETION_STREAM,
CONSUMER_GROUP,
config.stream_completion_name,
config.stream_consumer_group,
id="0",
mkstream=True,
)
logger.info(
f"Created consumer group '{CONSUMER_GROUP}' on stream '{COMPLETION_STREAM}'"
f"Created consumer group '{config.stream_consumer_group}' "
f"on stream '{config.stream_completion_name}'"
)
except ResponseError as e:
if "BUSYGROUP" in str(e):
logger.debug(f"Consumer group '{CONSUMER_GROUP}' already exists")
logger.debug(
f"Consumer group '{config.stream_consumer_group}' already exists"
)
else:
raise
@@ -134,9 +132,9 @@ class ChatCompletionConsumer:
while self._running:
# Read new messages from the stream
messages = await redis.xreadgroup(
groupname=CONSUMER_GROUP,
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
streams={COMPLETION_STREAM: ">"},
streams={config.stream_completion_name: ">"},
block=block_timeout,
count=10,
)
@@ -161,7 +159,9 @@ class ChatCompletionConsumer:
# Acknowledge the message
await redis.xack(
COMPLETION_STREAM, CONSUMER_GROUP, entry_id
config.stream_completion_name,
config.stream_consumer_group,
entry_id,
)
except Exception as e:
logger.error(
@@ -237,72 +237,8 @@ class ChatCompletionConsumer:
message: OperationCompleteMessage,
) -> None:
"""Handle successful operation completion."""
# Publish result to stream registry
result_output = message.result if message.result else {"status": "completed"}
await stream_registry.publish_chunk(
task.task_id,
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
output=(
result_output
if isinstance(result_output, str)
else orjson.dumps(result_output).decode("utf-8")
),
success=True,
),
)
# Update pending operation in database using our Prisma client
result_str = (
message.result
if isinstance(message.result, str)
else (
orjson.dumps(message.result).decode("utf-8")
if message.result
else '{"status": "completed"}'
)
)
try:
prisma = await self._ensure_prisma()
await prisma.chatmessage.update_many(
where={
"sessionId": task.session_id,
"toolCallId": task.tool_call_id,
},
data={"content": result_str},
)
logger.info(
f"[COMPLETION] Updated tool message for session {task.session_id}"
)
except Exception as e:
logger.error(
f"[COMPLETION] Failed to update tool message: {e}", exc_info=True
)
# Generate LLM continuation with streaming
try:
await chat_service._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
except Exception as e:
logger.error(
f"[COMPLETION] Failed to generate LLM continuation: {e}",
exc_info=True,
)
# Mark task as completed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="completed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
)
prisma = await self._ensure_prisma()
await process_operation_success(task, message.result, prisma)
async def _handle_failure(
self,
@@ -310,47 +246,8 @@ class ChatCompletionConsumer:
message: OperationCompleteMessage,
) -> None:
"""Handle failed operation completion."""
error_msg = message.error or "Operation failed"
# Publish error to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText=error_msg),
)
await stream_registry.publish_chunk(task.task_id, StreamFinish())
# Update pending operation with error using our Prisma client
error_response = ErrorResponse(
message=error_msg,
error=message.error,
)
try:
prisma = await self._ensure_prisma()
await prisma.chatmessage.update_many(
where={
"sessionId": task.session_id,
"toolCallId": task.tool_call_id,
},
data={"content": error_response.model_dump_json()},
)
logger.info(
f"[COMPLETION] Updated tool message with error for session {task.session_id}"
)
except Exception as e:
logger.error(
f"[COMPLETION] Failed to update tool message: {e}", exc_info=True
)
# Mark task as failed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="failed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(
f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}"
)
prisma = await self._ensure_prisma()
await process_operation_failure(task, message.error, prisma)
# Module-level consumer instance
@@ -399,8 +296,8 @@ async def publish_operation_complete(
redis = await get_redis_async()
await redis.xadd(
COMPLETION_STREAM,
config.stream_completion_name,
{"data": message.model_dump_json()},
maxlen=STREAM_MAX_LENGTH,
maxlen=config.stream_max_length,
)
logger.info(f"Published completion for operation {operation_id}")

View File

@@ -0,0 +1,255 @@
"""Shared completion handling for operation success and failure.
This module provides common logic for handling operation completion from both:
- The Redis Streams consumer (completion_consumer.py)
- The HTTP webhook endpoint (routes.py)
"""
import logging
from typing import Any
import orjson
from prisma import Prisma
from . import service as chat_service
from . import stream_registry
from .response_model import StreamError, StreamFinish, StreamToolOutputAvailable
from .tools.models import ErrorResponse
logger = logging.getLogger(__name__)
# Tools that produce agent_json that needs to be saved to library
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
def serialize_result(result: dict | str | None) -> str:
"""Serialize result to JSON string with sensible defaults.
Args:
result: The result to serialize (dict, string, or None)
Returns:
JSON string representation of the result
"""
if isinstance(result, str):
return result
if result:
return orjson.dumps(result).decode("utf-8")
return '{"status": "completed"}'
async def _save_agent_from_result(
result: dict[str, Any],
user_id: str | None,
tool_name: str,
) -> dict[str, Any]:
"""Save agent to library if result contains agent_json.
Args:
result: The result dict that may contain agent_json
user_id: The user ID to save the agent for
tool_name: The tool name (create_agent or edit_agent)
Returns:
Updated result dict with saved agent details, or original result if no agent_json
"""
if not user_id:
logger.warning(
"[COMPLETION] Cannot save agent: no user_id in task"
)
return result
agent_json = result.get("agent_json")
if not agent_json:
logger.warning(
f"[COMPLETION] {tool_name} completed but no agent_json in result"
)
return result
try:
from .tools.agent_generator import save_agent_to_library
is_update = tool_name == "edit_agent"
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id, is_update=is_update
)
logger.info(
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
)
# Return a response similar to AgentSavedResponse
return {
"type": "agent_saved",
"message": f"Agent '{created_graph.name}' has been saved to your library!",
"agent_id": created_graph.id,
"agent_name": created_graph.name,
"library_agent_id": library_agent.id,
"library_agent_link": f"/library/agents/{library_agent.id}",
"agent_page_link": f"/build?flowID={created_graph.id}",
}
except Exception as e:
logger.error(
f"[COMPLETION] Failed to save agent to library: {e}",
exc_info=True,
)
# Return error but don't fail the whole operation
return {
"type": "error",
"message": f"Agent was generated but failed to save: {str(e)}",
"error": str(e),
"agent_json": agent_json, # Include the JSON so user can retry
}
async def process_operation_success(
task: stream_registry.ActiveTask,
result: dict | str | None,
prisma_client: Prisma | None = None,
) -> None:
"""Handle successful operation completion.
Publishes the result to the stream registry, updates the database,
generates LLM continuation, and marks the task as completed.
Args:
task: The active task that completed
result: The result data from the operation
prisma_client: Optional Prisma client for database operations.
If None, uses chat_service._update_pending_operation instead.
"""
# For agent generation tools, save the agent to library
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
# Serialize result for output
result_output = result if result else {"status": "completed"}
output_str = (
result_output
if isinstance(result_output, str)
else orjson.dumps(result_output).decode("utf-8")
)
# Publish result to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
output=output_str,
success=True,
),
)
# Update pending operation in database
result_str = serialize_result(result)
try:
if prisma_client:
# Use provided Prisma client (for consumer with its own connection)
await prisma_client.chatmessage.update_many(
where={
"sessionId": task.session_id,
"toolCallId": task.tool_call_id,
},
data={"content": result_str},
)
logger.info(
f"[COMPLETION] Updated tool message for session {task.session_id}"
)
else:
# Use service function (for webhook endpoint)
await chat_service._update_pending_operation(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
result=result_str,
)
except Exception as e:
logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
# Generate LLM continuation with streaming
try:
await chat_service._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
except Exception as e:
logger.error(
f"[COMPLETION] Failed to generate LLM continuation: {e}",
exc_info=True,
)
# Mark task as completed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="completed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
)
async def process_operation_failure(
task: stream_registry.ActiveTask,
error: str | None,
prisma_client: Prisma | None = None,
) -> None:
"""Handle failed operation completion.
Publishes the error to the stream registry, updates the database with
the error response, and marks the task as failed.
Args:
task: The active task that failed
error: The error message from the operation
prisma_client: Optional Prisma client for database operations.
If None, uses chat_service._update_pending_operation instead.
"""
error_msg = error or "Operation failed"
# Publish error to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText=error_msg),
)
await stream_registry.publish_chunk(task.task_id, StreamFinish())
# Update pending operation with error
error_response = ErrorResponse(
message=error_msg,
error=error,
)
try:
if prisma_client:
# Use provided Prisma client (for consumer with its own connection)
await prisma_client.chatmessage.update_many(
where={
"sessionId": task.session_id,
"toolCallId": task.tool_call_id,
},
data={"content": error_response.model_dump_json()},
)
logger.info(
f"[COMPLETION] Updated tool message with error for session {task.session_id}"
)
else:
# Use service function (for webhook endpoint)
await chat_service._update_pending_operation(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
result=error_response.model_dump_json(),
)
except Exception as e:
logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
# Mark task as failed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="failed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")

View File

@@ -50,9 +50,37 @@ class ChatConfig(BaseSettings):
description="TTL in seconds for stream data in Redis (1 hour)",
)
stream_max_length: int = Field(
default=1000,
default=10000,
description="Maximum number of messages to store per stream",
)
# Redis Streams configuration for completion consumer
stream_completion_name: str = Field(
default="chat:completions",
description="Redis Stream name for operation completions",
)
stream_consumer_group: str = Field(
default="chat_consumers",
description="Consumer group name for completion stream",
)
# Redis key prefixes for stream registry
task_meta_prefix: str = Field(
default="chat:task:meta:",
description="Prefix for task metadata hash keys",
)
task_stream_prefix: str = Field(
default="chat:stream:",
description="Prefix for task message stream keys",
)
task_op_prefix: str = Field(
default="chat:task:op:",
description="Prefix for operation ID to task ID mapping keys",
)
task_pubsub_prefix: str = Field(
default="chat:task:pubsub:",
description="Prefix for task pub/sub channel names",
)
internal_api_key: str | None = Field(
default=None,
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",

View File

@@ -5,7 +5,6 @@ import uuid as uuid_module
from collections.abc import AsyncGenerator
from typing import Annotated
import orjson
from autogpt_libs import auth
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
from fastapi.responses import StreamingResponse
@@ -15,6 +14,7 @@ from backend.util.exceptions import NotFoundError
from . import service as chat_service
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
@@ -704,81 +704,9 @@ async def complete_operation(
)
if request.success:
# Publish result to stream registry
from .response_model import StreamToolOutputAvailable
result_output = request.result if request.result else {"status": "completed"}
await stream_registry.publish_chunk(
task.task_id,
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
output=(
result_output
if isinstance(result_output, str)
else orjson.dumps(result_output).decode("utf-8")
),
success=True,
),
)
# Update pending operation in database
from . import service as svc
result_str = (
request.result
if isinstance(request.result, str)
else (
orjson.dumps(request.result).decode("utf-8")
if request.result
else '{"status": "completed"}'
)
)
await svc._update_pending_operation(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
result=result_str,
)
# Generate LLM continuation with streaming
await svc._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
# Mark task as completed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="completed")
await svc._mark_operation_completed(task.tool_call_id)
await process_operation_success(task, request.result)
else:
# Publish error to stream registry
from .response_model import StreamError
error_msg = request.error or "Operation failed"
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText=error_msg),
)
# Send finish event to end the stream
await stream_registry.publish_chunk(task.task_id, StreamFinish())
# Update pending operation with error
from . import service as svc
from .tools.models import ErrorResponse
error_response = ErrorResponse(
message=error_msg,
error=request.error,
)
await svc._update_pending_operation(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
result=error_response.model_dump_json(),
)
# Mark task as failed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="failed")
await svc._mark_operation_completed(task.tool_call_id)
await process_operation_failure(task, request.error)
return {"status": "ok", "task_id": task.task_id}

View File

@@ -31,6 +31,9 @@ from .response_model import StreamBaseResponse, StreamFinish
logger = logging.getLogger(__name__)
config = ChatConfig()
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
_local_tasks: dict[str, asyncio.Task] = {}
@dataclass
class ActiveTask:
@@ -47,34 +50,24 @@ class ActiveTask:
asyncio_task: asyncio.Task | None = None
# Redis key patterns
TASK_META_PREFIX = "chat:task:meta:" # Hash for task metadata
TASK_STREAM_PREFIX = "chat:stream:" # Redis Stream for messages
TASK_OP_PREFIX = "chat:task:op:" # Operation ID -> task_id mapping
TASK_PUBSUB_PREFIX = "chat:task:pubsub:" # Pub/sub channel for real-time delivery
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
_local_tasks: dict[str, asyncio.Task] = {}
def _get_task_meta_key(task_id: str) -> str:
"""Get Redis key for task metadata."""
return f"{TASK_META_PREFIX}{task_id}"
return f"{config.task_meta_prefix}{task_id}"
def _get_task_stream_key(task_id: str) -> str:
"""Get Redis key for task message stream."""
return f"{TASK_STREAM_PREFIX}{task_id}"
return f"{config.task_stream_prefix}{task_id}"
def _get_operation_mapping_key(operation_id: str) -> str:
"""Get Redis key for operation_id to task_id mapping."""
return f"{TASK_OP_PREFIX}{operation_id}"
return f"{config.task_op_prefix}{operation_id}"
def _get_task_pubsub_channel(task_id: str) -> str:
"""Get Redis pub/sub channel for task real-time delivery."""
return f"{TASK_PUBSUB_PREFIX}{task_id}"
return f"{config.task_pubsub_prefix}{task_id}"
async def create_task(
@@ -466,7 +459,9 @@ async def get_active_task_for_session(
tasks_checked = 0
while True:
cursor, keys = await redis.scan(cursor, match=f"{TASK_META_PREFIX}*", count=100)
cursor, keys = await redis.scan(
cursor, match=f"{config.task_meta_prefix}*", count=100
)
for key in keys:
tasks_checked += 1
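Note: SCAN is cursor-based — Redis returns a new cursor with each batch, and a zero cursor means the keyspace walk is complete. A sketch of the full loop around the call shown above, with redis and config as in the module; the helper name is illustrative:

async def find_task_meta_keys(redis, config) -> list[str]:
    """Collect all task metadata keys without blocking Redis (sketch)."""
    found: list[str] = []
    cursor = 0
    while True:
        cursor, keys = await redis.scan(
            cursor, match=f"{config.task_meta_prefix}*", count=100
        )
        found.extend(keys)
        if cursor == 0:
            break  # zero cursor: the scan has covered the whole keyspace
    return found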

View File

@@ -0,0 +1,16 @@
/**
* Constants for the chat system.
*
* Centralizes magic strings and values used across chat components.
*/
// LocalStorage keys
export const STORAGE_KEY_ACTIVE_TASKS = "chat_active_tasks";
// Redis Stream IDs
export const INITIAL_MESSAGE_ID = "0";
export const INITIAL_STREAM_ID = "0-0";
// TTL values (in milliseconds)
export const COMPLETED_STREAM_TTL_MS = 5 * 60 * 1000; // 5 minutes
export const ACTIVE_TASK_TTL_MS = 60 * 60 * 1000; // 1 hour

View File

@@ -1,6 +1,12 @@
"use client";
import { create } from "zustand";
import {
ACTIVE_TASK_TTL_MS,
COMPLETED_STREAM_TTL_MS,
INITIAL_STREAM_ID,
STORAGE_KEY_ACTIVE_TASKS,
} from "./chat-constants";
import type {
ActiveStream,
StreamChunk,
@@ -10,10 +16,6 @@ import type {
} from "./chat-types";
import { executeStream, executeTaskReconnect } from "./stream-executor";
const COMPLETED_STREAM_TTL = 5 * 60 * 1000; // 5 minutes
const ACTIVE_TASKS_STORAGE_KEY = "chat_active_tasks";
const TASK_TTL = 60 * 60 * 1000; // 1 hour - tasks expire after this
/**
* Tracks active task info for SSE reconnection.
* When a long-running operation starts, we store this so clients can reconnect
@@ -32,14 +34,14 @@ export interface ActiveTaskInfo {
function loadPersistedTasks(): Map<string, ActiveTaskInfo> {
if (typeof window === "undefined") return new Map();
try {
const stored = localStorage.getItem(ACTIVE_TASKS_STORAGE_KEY);
const stored = localStorage.getItem(STORAGE_KEY_ACTIVE_TASKS);
if (!stored) return new Map();
const parsed = JSON.parse(stored) as Record<string, ActiveTaskInfo>;
const now = Date.now();
const tasks = new Map<string, ActiveTaskInfo>();
// Filter out expired tasks
for (const [sessionId, task] of Object.entries(parsed)) {
if (now - task.startedAt < TASK_TTL) {
if (now - task.startedAt < ACTIVE_TASK_TTL_MS) {
tasks.set(sessionId, task);
}
}
@@ -57,7 +59,7 @@ function persistTasks(tasks: Map<string, ActiveTaskInfo>): void {
for (const [sessionId, task] of tasks) {
obj[sessionId] = task;
}
localStorage.setItem(ACTIVE_TASKS_STORAGE_KEY, JSON.stringify(obj));
localStorage.setItem(STORAGE_KEY_ACTIVE_TASKS, JSON.stringify(obj));
} catch {
// Ignore storage errors
}
@@ -135,13 +137,73 @@ function cleanupExpiredStreams(
const now = Date.now();
const cleaned = new Map(completedStreams);
for (const [sessionId, result] of cleaned) {
if (now - result.completedAt > COMPLETED_STREAM_TTL) {
if (now - result.completedAt > COMPLETED_STREAM_TTL_MS) {
cleaned.delete(sessionId);
}
}
return cleaned;
}
/**
* Clean up an existing stream for a session and move it to completed streams.
* Returns updated maps for both active and completed streams.
*/
function cleanupExistingStream(
sessionId: string,
activeStreams: Map<string, ActiveStream>,
completedStreams: Map<string, StreamResult>,
callbacks: Set<StreamCompleteCallback>,
): {
activeStreams: Map<string, ActiveStream>;
completedStreams: Map<string, StreamResult>;
} {
const newActiveStreams = new Map(activeStreams);
let newCompletedStreams = new Map(completedStreams);
const existingStream = newActiveStreams.get(sessionId);
if (existingStream) {
existingStream.abortController.abort();
const normalizedStatus =
existingStream.status === "streaming" ? "completed" : existingStream.status;
const result: StreamResult = {
sessionId,
status: normalizedStatus,
chunks: existingStream.chunks,
completedAt: Date.now(),
error: existingStream.error,
};
newCompletedStreams.set(sessionId, result);
newActiveStreams.delete(sessionId);
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
if (normalizedStatus === "completed" || normalizedStatus === "error") {
notifyStreamComplete(callbacks, sessionId);
}
}
return { activeStreams: newActiveStreams, completedStreams: newCompletedStreams };
}
/**
* Create a new active stream with initial state.
*/
function createActiveStream(
sessionId: string,
onChunk?: (chunk: StreamChunk) => void,
): ActiveStream {
const abortController = new AbortController();
const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
if (onChunk) initialCallbacks.add(onChunk);
return {
sessionId,
abortController,
status: "streaming",
startedAt: Date.now(),
chunks: [],
onChunkCallbacks: initialCallbacks,
};
}
export const useChatStore = create<ChatStore>((set, get) => ({
activeStreams: new Map(),
completedStreams: new Map(),
@@ -157,45 +219,19 @@ export const useChatStore = create<ChatStore>((set, get) => ({
onChunk,
) {
const state = get();
const newActiveStreams = new Map(state.activeStreams);
let newCompletedStreams = new Map(state.completedStreams);
const callbacks = state.streamCompleteCallbacks;
const existingStream = newActiveStreams.get(sessionId);
if (existingStream) {
existingStream.abortController.abort();
const normalizedStatus =
existingStream.status === "streaming"
? "completed"
: existingStream.status;
const result: StreamResult = {
sessionId,
status: normalizedStatus,
chunks: existingStream.chunks,
completedAt: Date.now(),
error: existingStream.error,
};
newCompletedStreams.set(sessionId, result);
newActiveStreams.delete(sessionId);
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
if (normalizedStatus === "completed" || normalizedStatus === "error") {
notifyStreamComplete(callbacks, sessionId);
}
}
const abortController = new AbortController();
const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
if (onChunk) initialCallbacks.add(onChunk);
const stream: ActiveStream = {
sessionId,
abortController,
status: "streaming",
startedAt: Date.now(),
chunks: [],
onChunkCallbacks: initialCallbacks,
};
// Clean up any existing stream for this session
const { activeStreams: newActiveStreams, completedStreams: newCompletedStreams } =
cleanupExistingStream(
sessionId,
state.activeStreams,
state.completedStreams,
callbacks,
);
// Create new stream
const stream = createActiveStream(sessionId, onChunk);
newActiveStreams.set(sessionId, stream);
set({
activeStreams: newActiveStreams,
@@ -388,7 +424,7 @@ export const useChatStore = create<ChatStore>((set, get) => ({
reconnectToTask: async function reconnectToTask(
sessionId,
taskId,
lastMessageId = "0-0", // Redis Stream ID format
lastMessageId = INITIAL_STREAM_ID,
onChunk,
) {
console.info("[SSE-RECONNECT] reconnectToTask called:", {
@@ -398,43 +434,19 @@ export const useChatStore = create<ChatStore>((set, get) => ({
});
const state = get();
const newActiveStreams = new Map(state.activeStreams);
let newCompletedStreams = new Map(state.completedStreams);
const existingStream = newActiveStreams.get(sessionId);
if (existingStream) {
existingStream.abortController.abort();
const normalizedStatus =
existingStream.status === "streaming"
? "completed"
: existingStream.status;
const result: StreamResult = {
sessionId,
status: normalizedStatus,
chunks: existingStream.chunks,
completedAt: Date.now(),
error: existingStream.error,
};
newCompletedStreams.set(sessionId, result);
newActiveStreams.delete(sessionId);
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
}
const abortController = new AbortController();
const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
if (onChunk) initialCallbacks.add(onChunk);
const stream: ActiveStream = {
sessionId,
abortController,
status: "streaming",
startedAt: Date.now(),
chunks: [],
onChunkCallbacks: initialCallbacks,
};
const callbacks = state.streamCompleteCallbacks;
// Clean up any existing stream for this session
const { activeStreams: newActiveStreams, completedStreams: newCompletedStreams } =
cleanupExistingStream(
sessionId,
state.activeStreams,
state.completedStreams,
callbacks,
);
// Create new stream for reconnection
const stream = createActiveStream(sessionId, onChunk);
newActiveStreams.set(sessionId, stream);
set({
activeStreams: newActiveStreams,
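Note: taken together, a hypothetical caller resuming a background task after a page reload would look roughly like this; sessionId and taskId would come from the persisted ActiveTaskInfo, and the wrapper function is illustrative:

import { INITIAL_STREAM_ID } from "./chat-constants";
import { useChatStore } from "./chat-store";

async function resumeTask(sessionId: string, taskId: string): Promise<void> {
  const { reconnectToTask } = useChatStore.getState();
  // INITIAL_STREAM_ID ("0-0") requests a full replay of the task's stream.
  await reconnectToTask(sessionId, taskId, INITIAL_STREAM_ID, (chunk) => {
    console.debug("chunk:", chunk.type);
  });
}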

View File

@@ -94,3 +94,67 @@ export interface StreamResult {
}
export type StreamCompleteCallback = (sessionId: string) => void;
// Type guards for message types
/**
* Check if a message has a toolId property.
*/
export function hasToolId<T extends { type: string }>(
msg: T,
): msg is T & { toolId: string } {
return "toolId" in msg && typeof (msg as Record<string, unknown>).toolId === "string";
}
/**
* Check if a message has an operationId property.
*/
export function hasOperationId<T extends { type: string }>(
msg: T,
): msg is T & { operationId: string } {
return (
"operationId" in msg &&
typeof (msg as Record<string, unknown>).operationId === "string"
);
}
/**
* Check if a message has a toolCallId property.
*/
export function hasToolCallId<T extends { type: string }>(
msg: T,
): msg is T & { toolCallId: string } {
return (
"toolCallId" in msg &&
typeof (msg as Record<string, unknown>).toolCallId === "string"
);
}
/**
* Check if a message is an operation message type.
*/
export function isOperationMessage<T extends { type: string }>(
msg: T,
): msg is T & {
type: "operation_started" | "operation_pending" | "operation_in_progress";
} {
return (
msg.type === "operation_started" ||
msg.type === "operation_pending" ||
msg.type === "operation_in_progress"
);
}
/**
* Get the tool ID from a message if available.
* Checks toolId, operationId, and toolCallId properties.
*/
export function getToolIdFromMessage<T extends { type: string }>(
msg: T,
): string | undefined {
const record = msg as Record<string, unknown>;
if (typeof record.toolId === "string") return record.toolId;
if (typeof record.operationId === "string") return record.operationId;
if (typeof record.toolCallId === "string") return record.toolCallId;
return undefined;
}
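Note: a short sketch of how these guards narrow a message union in practice; the Msg shape and describe function are illustrative, not part of the codebase:

import {
  getToolIdFromMessage,
  hasToolCallId,
  isOperationMessage,
} from "./chat-types";

type Msg = { type: string; toolName?: string };

function describe(msg: Msg): string {
  if (isOperationMessage(msg)) {
    // Narrowed: msg.type is one of the three operation_* literals.
    return `op:${getToolIdFromMessage(msg) ?? "unknown"}:${msg.toolName ?? ""}`;
  }
  if (hasToolCallId(msg)) {
    // Narrowed: msg now includes toolCallId: string.
    return `tool:${msg.toolCallId}`;
  }
  return msg.type;
}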

View File

@@ -1,10 +1,16 @@
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
import { useEffect, useMemo, useRef, useState } from "react";
import { INITIAL_STREAM_ID } from "../../chat-constants";
import { useChatStore } from "../../chat-store";
import { toast } from "sonner";
import { useChatStream } from "../../useChatStream";
import { usePageContext } from "../../usePageContext";
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
import {
getToolIdFromMessage,
hasToolId,
isOperationMessage,
} from "../../chat-types";
import { createStreamEventDispatcher } from "./createStreamEventDispatcher";
import {
createUserMessage,
@@ -14,6 +20,46 @@ import {
processInitialMessages,
} from "./helpers";
/**
* Dependencies for creating a stream event dispatcher.
* Extracted so the createDispatcher helper below can accept a single typed object.
*/
interface DispatcherDeps {
setHasTextChunks: React.Dispatch<React.SetStateAction<boolean>>;
setStreamingChunks: React.Dispatch<React.SetStateAction<string[]>>;
streamingChunksRef: React.MutableRefObject<string[]>;
hasResponseRef: React.MutableRefObject<boolean>;
setMessages: React.Dispatch<React.SetStateAction<ChatMessageData[]>>;
setIsRegionBlockedModalOpen: React.Dispatch<React.SetStateAction<boolean>>;
sessionId: string;
setIsStreamingInitiated: React.Dispatch<React.SetStateAction<boolean>>;
onOperationStarted?: () => void;
onActiveTaskStarted: (taskInfo: {
taskId: string;
operationId: string;
toolName: string;
toolCallId: string;
}) => void;
}
/**
* Create a stream event dispatcher with the given dependencies.
*/
function createDispatcher(deps: DispatcherDeps) {
return createStreamEventDispatcher({
setHasTextChunks: deps.setHasTextChunks,
setStreamingChunks: deps.setStreamingChunks,
streamingChunksRef: deps.streamingChunksRef,
hasResponseRef: deps.hasResponseRef,
setMessages: deps.setMessages,
setIsRegionBlockedModalOpen: deps.setIsRegionBlockedModalOpen,
sessionId: deps.sessionId,
setIsStreamingInitiated: deps.setIsStreamingInitiated,
onOperationStarted: deps.onOperationStarted,
onActiveTaskStarted: deps.onActiveTaskStarted,
});
}
// Helper to generate deduplication key for a message
function getMessageKey(msg: ChatMessageData): string {
if (msg.type === "message") {
@@ -24,13 +70,11 @@ function getMessageKey(msg: ChatMessageData): string {
} else if (msg.type === "tool_call") {
return `toolcall:${msg.toolId}`;
} else if (msg.type === "tool_response") {
return `toolresponse:${(msg as any).toolId}`;
} else if (
msg.type === "operation_started" ||
msg.type === "operation_pending" ||
msg.type === "operation_in_progress"
) {
return `op:${(msg as any).toolId || (msg as any).operationId || (msg as any).toolCallId || ""}:${msg.toolName}`;
const toolId = hasToolId(msg) ? msg.toolId : "";
return `toolresponse:${toolId}`;
} else if (isOperationMessage(msg)) {
const toolId = getToolIdFromMessage(msg) || "";
return `op:${toolId}:${msg.toolName}`;
} else {
return `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`;
}
@@ -90,7 +134,7 @@ export function useChatContainer({
taskId: taskInfo.taskId,
operationId: taskInfo.operationId,
toolName: taskInfo.toolName,
lastMessageId: "0-0", // Redis Stream ID format for full replay
lastMessageId: INITIAL_STREAM_ID,
});
}
@@ -168,7 +212,7 @@ export function useChatContainer({
},
);
const dispatcher = createStreamEventDispatcher({
const dispatcher = createDispatcher({
setHasTextChunks,
setStreamingChunks,
streamingChunksRef,
@@ -221,7 +265,7 @@ export function useChatContainer({
},
);
const dispatcher = createStreamEventDispatcher({
const dispatcher = createDispatcher({
setHasTextChunks,
setStreamingChunks,
streamingChunksRef,
@@ -259,7 +303,7 @@ export function useChatContainer({
return;
}
const dispatcher = createStreamEventDispatcher({
const dispatcher = createDispatcher({
setHasTextChunks,
setStreamingChunks,
streamingChunksRef,
@@ -300,7 +344,7 @@ export function useChatContainer({
msg.type === "agent_carousel" ||
msg.type === "execution_started"
) {
const toolId = (msg as any).toolId;
const toolId = hasToolId(msg) ? msg.toolId : undefined;
if (toolId) {
ids.add(toolId);
}
@@ -317,12 +361,8 @@ export function useChatContainer({
setMessages((prev) => {
const filtered = prev.filter((msg) => {
if (
msg.type === "operation_started" ||
msg.type === "operation_pending" ||
msg.type === "operation_in_progress"
) {
const toolId = (msg as any).toolId || (msg as any).toolCallId;
if (isOperationMessage(msg)) {
const toolId = getToolIdFromMessage(msg);
if (toolId && completedToolIds.has(toolId)) {
return false; // Remove - operation completed
}
@@ -350,12 +390,8 @@ export function useChatContainer({
// Filter local messages: remove duplicates and completed operation messages
const newLocalMessages = messages.filter((msg) => {
// Remove operation messages for completed tools
if (
msg.type === "operation_started" ||
msg.type === "operation_pending" ||
msg.type === "operation_in_progress"
) {
const toolId = (msg as any).toolId || (msg as any).toolCallId;
if (isOperationMessage(msg)) {
const toolId = getToolIdFromMessage(msg);
if (toolId && completedToolIds.has(toolId)) {
return false;
}
@@ -391,7 +427,7 @@ export function useChatContainer({
setIsStreamingInitiated(true);
hasResponseRef.current = false;
const dispatcher = createStreamEventDispatcher({
const dispatcher = createDispatcher({
setHasTextChunks,
setStreamingChunks,
streamingChunksRef,
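Note: with the guards in place, the deduplication keys produced by getMessageKey become predictable; for example (message shapes illustrative):

// { type: "tool_response", toolId: "t1" }                          -> "toolresponse:t1"
// { type: "operation_started", operationId: "op9",
//   toolName: "create_agent" }                                     -> "op:op9:create_agent"
// { type: "operation_pending", toolName: "edit_agent" }            -> "op::edit_agent"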

View File

@@ -1,3 +1,4 @@
import { INITIAL_MESSAGE_ID } from "./chat-constants";
import type {
ActiveStream,
StreamChunk,
@@ -27,178 +28,118 @@ function notifySubscribers(
}
}
export async function executeStream(
stream: ActiveStream,
message: string,
isUserMessage: boolean,
context?: { url: string; content: string },
retryCount: number = 0,
): Promise<void> {
const { sessionId, abortController } = stream;
try {
const url = `/api/chat/sessions/${sessionId}/stream`;
const body = JSON.stringify({
message,
is_user_message: isUserMessage,
context: context || null,
});
const response = await fetch(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
Accept: "text/event-stream",
},
body,
signal: abortController.signal,
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(errorText || `HTTP ${response.status}`);
}
if (!response.body) {
throw new Error("Response body is null");
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
let buffer = "";
while (true) {
const { done, value } = await reader.read();
if (done) {
notifySubscribers(stream, { type: "stream_end" });
stream.status = "completed";
return;
}
buffer += decoder.decode(value, { stream: true });
const lines = buffer.split("\n");
buffer = lines.pop() || "";
for (const line of lines) {
const data = parseSSELine(line);
if (data !== null) {
if (data === "[DONE]") {
notifySubscribers(stream, { type: "stream_end" });
stream.status = "completed";
return;
}
try {
const rawChunk = JSON.parse(data) as
| StreamChunk
| VercelStreamChunk;
const chunk = normalizeStreamChunk(rawChunk);
if (!chunk) continue;
notifySubscribers(stream, chunk);
if (chunk.type === "stream_end") {
stream.status = "completed";
return;
}
if (chunk.type === "error") {
stream.status = "error";
stream.error = new Error(
chunk.message || chunk.content || "Stream error",
);
return;
}
} catch (err) {
console.warn("[StreamExecutor] Failed to parse SSE chunk:", err);
}
}
}
}
} catch (err) {
if (err instanceof Error && err.name === "AbortError") {
notifySubscribers(stream, { type: "stream_end" });
stream.status = "completed";
return;
}
if (retryCount < MAX_RETRIES) {
const retryDelay = INITIAL_RETRY_DELAY * Math.pow(2, retryCount);
console.log(
`[StreamExecutor] Retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`,
);
await new Promise((resolve) => setTimeout(resolve, retryDelay));
return executeStream(
stream,
message,
isUserMessage,
context,
retryCount + 1,
);
}
stream.status = "error";
stream.error = err instanceof Error ? err : new Error("Stream failed");
notifySubscribers(stream, {
type: "error",
message: stream.error.message,
});
}
/**
* Options for stream execution.
*/
interface StreamExecutionOptions {
/** The active stream state object */
stream: ActiveStream;
/** Execution mode: 'new' for new stream, 'reconnect' for task reconnection */
mode: "new" | "reconnect";
/** Message content (required for 'new' mode) */
message?: string;
/** Whether this is a user message (for 'new' mode) */
isUserMessage?: boolean;
/** Optional context for the message (for 'new' mode) */
context?: { url: string; content: string };
/** Task ID (required for 'reconnect' mode) */
taskId?: string;
/** Last message ID for replay (for 'reconnect' mode) */
lastMessageId?: string;
/** Current retry count (internal use) */
retryCount?: number;
}
/**
* Reconnect to an existing task stream.
* Unified stream execution function that handles both new streams and task reconnection.
*
* This is used when a client wants to resume receiving updates from a
* long-running background task. Messages are replayed from the last_message_id
* position, allowing clients to catch up on missed events.
* For new streams:
* - Posts a message to create a new chat stream
* - Reads SSE chunks and notifies subscribers
*
* @param stream - The active stream state
* @param taskId - The task ID to reconnect to
* @param lastMessageId - The last message ID received (for replay)
* @param retryCount - Current retry count
* For reconnection:
* - Connects to an existing task stream
* - Replays messages from lastMessageId position
* - Allows resumption of long-running operations
*/
export async function executeTaskReconnect(
stream: ActiveStream,
taskId: string,
lastMessageId: string = "0",
retryCount: number = 0,
async function executeStreamInternal(
options: StreamExecutionOptions,
): Promise<void> {
const { abortController } = stream;
console.info("[SSE-RECONNECT] executeTaskReconnect starting:", {
taskId,
lastMessageId,
retryCount,
});
const {
stream,
mode,
message,
isUserMessage,
context,
taskId,
lastMessageId = INITIAL_MESSAGE_ID,
retryCount = 0,
} = options;
const { sessionId, abortController } = stream;
const isReconnect = mode === "reconnect";
const logPrefix = isReconnect ? "[SSE-RECONNECT]" : "[StreamExecutor]";
if (isReconnect) {
console.info(`${logPrefix} executeStream starting:`, {
taskId,
lastMessageId,
retryCount,
});
}
try {
const url = `/api/chat/tasks/${taskId}/stream?last_message_id=${encodeURIComponent(lastMessageId)}`;
console.info("[SSE-RECONNECT] Fetching task stream:", { url });
// Build URL and request options based on mode
let url: string;
let fetchOptions: RequestInit;
const response = await fetch(url, {
method: "GET",
headers: {
Accept: "text/event-stream",
},
signal: abortController.signal,
});
if (isReconnect) {
url = `/api/chat/tasks/${taskId}/stream?last_message_id=${encodeURIComponent(lastMessageId)}`;
fetchOptions = {
method: "GET",
headers: {
Accept: "text/event-stream",
},
signal: abortController.signal,
};
console.info(`${logPrefix} Fetching task stream:`, { url });
} else {
url = `/api/chat/sessions/${sessionId}/stream`;
fetchOptions = {
method: "POST",
headers: {
"Content-Type": "application/json",
Accept: "text/event-stream",
},
body: JSON.stringify({
message,
is_user_message: isUserMessage,
context: context || null,
}),
signal: abortController.signal,
};
}
console.info("[SSE-RECONNECT] Task stream response:", {
status: response.status,
ok: response.ok,
});
const response = await fetch(url, fetchOptions);
if (isReconnect) {
console.info(`${logPrefix} Task stream response:`, {
status: response.status,
ok: response.ok,
});
}
if (!response.ok) {
const errorText = await response.text();
console.error("[SSE-RECONNECT] Task stream error response:", {
status: response.status,
errorText,
});
// Don't retry on 404 (task not found) or 403 (access denied) - these are permanent errors
if (isReconnect) {
console.error(`${logPrefix} Task stream error response:`, {
status: response.status,
errorText,
});
}
// For reconnect: don't retry on 404/403 (permanent errors)
const isPermanentError =
response.status === 404 || response.status === 403;
isReconnect && (response.status === 404 || response.status === 403);
const error = new Error(errorText || `HTTP ${response.status}`);
(error as Error & { status?: number }).status = response.status;
(error as Error & { isPermanent?: boolean }).isPermanent =
@@ -210,7 +151,9 @@ export async function executeTaskReconnect(
throw new Error("Response body is null");
}
console.info("[SSE-RECONNECT] Task stream connected, reading chunks...");
if (isReconnect) {
console.info(`${logPrefix} Task stream connected, reading chunks...`);
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
@@ -220,7 +163,11 @@ export async function executeTaskReconnect(
const { done, value } = await reader.read();
if (done) {
console.info("[SSE-RECONNECT] Task stream reader done (connection closed)");
if (isReconnect) {
console.info(
`${logPrefix} Task stream reader done (connection closed)`,
);
}
notifySubscribers(stream, { type: "stream_end" });
stream.status = "completed";
return;
@@ -234,7 +181,9 @@ export async function executeTaskReconnect(
const data = parseSSELine(line);
if (data !== null) {
if (data === "[DONE]") {
console.info("[SSE-RECONNECT] Task stream received [DONE] signal");
if (isReconnect) {
console.info(`${logPrefix} Task stream received [DONE] signal`);
}
notifySubscribers(stream, { type: "stream_end" });
stream.status = "completed";
return;
@@ -247,9 +196,9 @@ export async function executeTaskReconnect(
const chunk = normalizeStreamChunk(rawChunk);
if (!chunk) continue;
// Log first few chunks for debugging
if (stream.chunks.length < 3) {
console.info("[SSE-RECONNECT] Task stream chunk received:", {
// Log first few chunks for debugging (reconnect mode only)
if (isReconnect && stream.chunks.length < 3) {
console.info(`${logPrefix} Task stream chunk received:`, {
type: chunk.type,
chunkIndex: stream.chunks.length,
});
@@ -258,13 +207,19 @@ export async function executeTaskReconnect(
notifySubscribers(stream, chunk);
if (chunk.type === "stream_end") {
console.info("[SSE-RECONNECT] Task stream completed via stream_end chunk");
if (isReconnect) {
console.info(
`${logPrefix} Task stream completed via stream_end chunk`,
);
}
stream.status = "completed";
return;
}
if (chunk.type === "error") {
console.error("[SSE-RECONNECT] Task stream error chunk:", chunk);
if (isReconnect) {
console.error(`${logPrefix} Task stream error chunk:`, chunk);
}
stream.status = "error";
stream.error = new Error(
chunk.message || chunk.content || "Stream error",
@@ -272,10 +227,7 @@ export async function executeTaskReconnect(
return;
}
} catch (err) {
console.warn(
"[StreamExecutor] Failed to parse task reconnect SSE chunk:",
err,
);
console.warn(`${logPrefix} Failed to parse SSE chunk:`, err);
}
}
}
@@ -295,30 +247,76 @@ export async function executeTaskReconnect(
if (!isPermanentError && retryCount < MAX_RETRIES) {
const retryDelay = INITIAL_RETRY_DELAY * Math.pow(2, retryCount);
console.log(
`[StreamExecutor] Task reconnect retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`,
`${logPrefix} Retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`,
);
await new Promise((resolve) => setTimeout(resolve, retryDelay));
return executeTaskReconnect(
stream,
taskId,
lastMessageId,
retryCount + 1,
);
return executeStreamInternal({
...options,
retryCount: retryCount + 1,
});
}
// Log permanent errors differently for debugging
if (isPermanentError) {
console.log(
`[StreamExecutor] Task reconnect failed permanently (task not found or access denied): ${(err as Error).message}`,
`${logPrefix} Stream failed permanently (task not found or access denied): ${(err as Error).message}`,
);
}
stream.status = "error";
stream.error =
err instanceof Error ? err : new Error("Task reconnect failed");
stream.error = err instanceof Error ? err : new Error("Stream failed");
notifySubscribers(stream, {
type: "error",
message: stream.error.message,
});
}
}
/**
* Execute a new chat stream.
*
* Posts a message to create a new stream and reads SSE responses.
*/
export async function executeStream(
stream: ActiveStream,
message: string,
isUserMessage: boolean,
context?: { url: string; content: string },
retryCount: number = 0,
): Promise<void> {
return executeStreamInternal({
stream,
mode: "new",
message,
isUserMessage,
context,
retryCount,
});
}
/**
* Reconnect to an existing task stream.
*
* This is used when a client wants to resume receiving updates from a
* long-running background task. Messages are replayed from the last_message_id
* position, allowing clients to catch up on missed events.
*
* @param stream - The active stream state
* @param taskId - The task ID to reconnect to
* @param lastMessageId - The last message ID received (for replay)
* @param retryCount - Current retry count
*/
export async function executeTaskReconnect(
stream: ActiveStream,
taskId: string,
lastMessageId: string = INITIAL_MESSAGE_ID,
retryCount: number = 0,
): Promise<void> {
return executeStreamInternal({
stream,
mode: "reconnect",
taskId,
lastMessageId,
retryCount,
});
}
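Note: both exported wrappers now share one code path; a usage sketch, where ActiveStream comes from chat-types and the surrounding demo function is illustrative:

import type { ActiveStream } from "./chat-types";

async function demo(stream: ActiveStream, taskId: string): Promise<void> {
  // New conversation turn: POSTs to /api/chat/sessions/{sessionId}/stream.
  await executeStream(stream, "Build me an agent", true);
  // Resume a background task: GETs /api/chat/tasks/{taskId}/stream,
  // replaying from INITIAL_MESSAGE_ID ("0").
  await executeTaskReconnect(stream, taskId, INITIAL_MESSAGE_ID);
}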