Compare commits


1 Commit

Author: Bentlybro
SHA1: 6b1f0df58c
Date: 2026-02-02 14:15:25 +00:00

fix(backend): Clean up orphaned schedules without schedule_id

Old scheduled jobs created before schedule_id was added to
GraphExecutionJobArgs have schedule_id=None. When these fail
validation, _handle_graph_validation_error could not unschedule
them, causing them to fire repeatedly and generate ~60K+ Sentry
errors (AUTOGPT-SERVER-6W2 and AUTOGPT-SERVER-6W3).

Fix: Add _cleanup_old_schedules_without_id(), which finds schedules
for the graph but only removes those with schedule_id=None (legacy
jobs). This preserves any valid newer schedules the user may have
created, unlike the broader _cleanup_orphaned_schedules_for_graph(),
which removes all schedules for a graph.
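
A minimal sketch of the fix described above, assuming an APScheduler-style scheduler; apart from the names quoted in the commit message, every identifier and import path here is illustrative:

# Hypothetical sketch only: the scheduler API and kwarg layout are assumptions;
# GraphExecutionJobArgs and schedule_id come from the commit message.
from apscheduler.schedulers.background import BackgroundScheduler

from backend.executor.scheduler import GraphExecutionJobArgs  # assumed import path

def cleanup_old_schedules_without_id(scheduler: BackgroundScheduler, graph_id: str) -> None:
    """Remove legacy schedules whose job args predate the schedule_id field."""
    for job in scheduler.get_jobs():
        args = GraphExecutionJobArgs(**job.kwargs)  # assumed to round-trip from job kwargs
        if args.graph_id != graph_id:
            continue  # leave other graphs' schedules alone
        if args.schedule_id is None:
            # Legacy job: it cannot be unscheduled via schedule_id, so drop it
            # by its scheduler handle; newer, valid schedules are preserved.
            job.remove()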
72 changed files with 1248 additions and 6019 deletions

.gitignore

@@ -180,4 +180,3 @@ autogpt_platform/backend/settings.py
 .claude/settings.local.json
 CLAUDE.local.md
 /autogpt_platform/backend/logs
-.next

completion_consumer.py

@@ -1,368 +0,0 @@
"""Redis Streams consumer for operation completion messages.
This module provides a consumer (ChatCompletionConsumer) that listens for
completion notifications (OperationCompleteMessage) from external services
(like Agent Generator) and triggers the appropriate stream registry and
chat service updates via process_operation_success/process_operation_failure.
Why Redis Streams instead of RabbitMQ?
--------------------------------------
While the project typically uses RabbitMQ for async task queues (e.g., execution
queue), Redis Streams was chosen for chat completion notifications because:
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
Streams (via stream_registry) for message persistence and replay. Using Redis
Streams for completion notifications keeps all chat streaming infrastructure
in one system, simplifying operations and reducing cross-system coordination.
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
allowing consumers to replay missed messages after reconnection. This aligns
with the SSE reconnection pattern where clients can resume from last_message_id.
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
recovering from dead consumers - ideal for the completion callback pattern.
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
stream_registry) provides lower latency than an additional RabbitMQ hop.
5. **Atomicity with Task State**: Completion processing often needs to update
task metadata stored in Redis. Keeping both in Redis enables simpler
transactional semantics without distributed coordination.
The consumer uses Redis Streams with consumer groups for reliable message
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
stale pending messages from dead consumers.
"""
import asyncio
import logging
import os
import uuid
from typing import Any
import orjson
from prisma import Prisma
from pydantic import BaseModel
from redis.exceptions import ResponseError
from backend.data.redis_client import get_redis_async
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
logger = logging.getLogger(__name__)
config = ChatConfig()
class OperationCompleteMessage(BaseModel):
"""Message format for operation completion notifications."""
operation_id: str
task_id: str
success: bool
result: dict | str | None = None
error: str | None = None
class ChatCompletionConsumer:
"""Consumer for chat operation completion messages from Redis Streams.
This consumer initializes its own Prisma client in start() to ensure
database operations work correctly within this async context.
Uses Redis consumer groups to allow multiple platform pods to consume
messages reliably with automatic redelivery on failure.
"""
def __init__(self):
self._consumer_task: asyncio.Task | None = None
self._running = False
self._prisma: Prisma | None = None
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
async def start(self) -> None:
"""Start the completion consumer."""
if self._running:
logger.warning("Completion consumer already running")
return
# Create consumer group if it doesn't exist
try:
redis = await get_redis_async()
await redis.xgroup_create(
config.stream_completion_name,
config.stream_consumer_group,
id="0",
mkstream=True,
)
logger.info(
f"Created consumer group '{config.stream_consumer_group}' "
f"on stream '{config.stream_completion_name}'"
)
except ResponseError as e:
if "BUSYGROUP" in str(e):
logger.debug(
f"Consumer group '{config.stream_consumer_group}' already exists"
)
else:
raise
self._running = True
self._consumer_task = asyncio.create_task(self._consume_messages())
logger.info(
f"Chat completion consumer started (consumer: {self._consumer_name})"
)
async def _ensure_prisma(self) -> Prisma:
"""Lazily initialize Prisma client on first use."""
if self._prisma is None:
database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
self._prisma = Prisma(datasource={"url": database_url})
await self._prisma.connect()
logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
return self._prisma
async def stop(self) -> None:
"""Stop the completion consumer."""
self._running = False
if self._consumer_task:
self._consumer_task.cancel()
try:
await self._consumer_task
except asyncio.CancelledError:
pass
self._consumer_task = None
if self._prisma:
await self._prisma.disconnect()
self._prisma = None
logger.info("[COMPLETION] Consumer Prisma client disconnected")
logger.info("Chat completion consumer stopped")
async def _consume_messages(self) -> None:
"""Main message consumption loop with retry logic."""
max_retries = 10
retry_delay = 5 # seconds
retry_count = 0
block_timeout = 5000 # milliseconds
while self._running and retry_count < max_retries:
try:
redis = await get_redis_async()
# Reset retry count on successful connection
retry_count = 0
while self._running:
# First, claim any stale pending messages from dead consumers
# Redis does NOT auto-redeliver pending messages; we must explicitly
# claim them using XAUTOCLAIM
try:
claimed_result = await redis.xautoclaim(
name=config.stream_completion_name,
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
min_idle_time=config.stream_claim_min_idle_ms,
start_id="0-0",
count=10,
)
# xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
if claimed_result and len(claimed_result) >= 2:
claimed_entries = claimed_result[1]
if claimed_entries:
logger.info(
f"Claimed {len(claimed_entries)} stale pending messages"
)
for entry_id, data in claimed_entries:
if not self._running:
return
await self._process_entry(redis, entry_id, data)
except Exception as e:
logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")
# Read new messages from the stream
messages = await redis.xreadgroup(
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
streams={config.stream_completion_name: ">"},
block=block_timeout,
count=10,
)
if not messages:
continue
for stream_name, entries in messages:
for entry_id, data in entries:
if not self._running:
return
await self._process_entry(redis, entry_id, data)
except asyncio.CancelledError:
logger.info("Consumer cancelled")
return
except Exception as e:
retry_count += 1
logger.error(
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
exc_info=True,
)
if self._running and retry_count < max_retries:
await asyncio.sleep(retry_delay)
else:
logger.error("Max retries reached, stopping consumer")
return
async def _process_entry(
self, redis: Any, entry_id: str, data: dict[str, Any]
) -> None:
"""Process a single stream entry and acknowledge it on success.
Args:
redis: Redis client connection
entry_id: The stream entry ID
data: The entry data dict
"""
try:
# Handle the message
message_data = data.get("data")
if message_data:
await self._handle_message(
message_data.encode()
if isinstance(message_data, str)
else message_data
)
# Acknowledge the message after successful processing
await redis.xack(
config.stream_completion_name,
config.stream_consumer_group,
entry_id,
)
except Exception as e:
logger.error(
f"Error processing completion message {entry_id}: {e}",
exc_info=True,
)
# Message remains in pending state and will be claimed by
# XAUTOCLAIM after min_idle_time expires
async def _handle_message(self, body: bytes) -> None:
"""Handle a completion message using our own Prisma client."""
try:
data = orjson.loads(body)
message = OperationCompleteMessage(**data)
except Exception as e:
logger.error(f"Failed to parse completion message: {e}")
return
logger.info(
f"[COMPLETION] Received completion for operation {message.operation_id} "
f"(task_id={message.task_id}, success={message.success})"
)
# Find task in registry
task = await stream_registry.find_task_by_operation_id(message.operation_id)
if task is None:
task = await stream_registry.get_task(message.task_id)
if task is None:
logger.warning(
f"[COMPLETION] Task not found for operation {message.operation_id} "
f"(task_id={message.task_id})"
)
return
logger.info(
f"[COMPLETION] Found task: task_id={task.task_id}, "
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
)
# Guard against empty task fields
if not task.task_id or not task.session_id or not task.tool_call_id:
logger.error(
f"[COMPLETION] Task has empty critical fields! "
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
f"tool_call_id={task.tool_call_id!r}"
)
return
if message.success:
await self._handle_success(task, message)
else:
await self._handle_failure(task, message)
async def _handle_success(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle successful operation completion."""
prisma = await self._ensure_prisma()
await process_operation_success(task, message.result, prisma)
async def _handle_failure(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle failed operation completion."""
prisma = await self._ensure_prisma()
await process_operation_failure(task, message.error, prisma)
# Module-level consumer instance
_consumer: ChatCompletionConsumer | None = None
async def start_completion_consumer() -> None:
"""Start the global completion consumer."""
global _consumer
if _consumer is None:
_consumer = ChatCompletionConsumer()
await _consumer.start()
async def stop_completion_consumer() -> None:
"""Stop the global completion consumer."""
global _consumer
if _consumer:
await _consumer.stop()
_consumer = None
async def publish_operation_complete(
operation_id: str,
task_id: str,
success: bool,
result: dict | str | None = None,
error: str | None = None,
) -> None:
"""Publish an operation completion message to Redis Streams.
Args:
operation_id: The operation ID that completed.
task_id: The task ID associated with the operation.
success: Whether the operation succeeded.
result: The result data (for success).
error: The error message (for failure).
"""
message = OperationCompleteMessage(
operation_id=operation_id,
task_id=task_id,
success=success,
result=result,
error=error,
)
redis = await get_redis_async()
await redis.xadd(
config.stream_completion_name,
{"data": message.model_dump_json()},
maxlen=config.stream_max_length,
)
logger.info(f"Published completion for operation {operation_id}")

completion_handler.py

@@ -1,344 +0,0 @@
"""Shared completion handling for operation success and failure.
This module provides common logic for handling operation completion from both:
- The Redis Streams consumer (completion_consumer.py)
- The HTTP webhook endpoint (routes.py)
"""
import logging
from typing import Any
import orjson
from prisma import Prisma
from . import service as chat_service
from . import stream_registry
from .response_model import StreamError, StreamToolOutputAvailable
from .tools.models import ErrorResponse
logger = logging.getLogger(__name__)
# Tools that produce agent_json that needs to be saved to library
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
# Keys that should be stripped from agent_json when returning in error responses
SENSITIVE_KEYS = frozenset(
{
"api_key",
"apikey",
"api_secret",
"password",
"secret",
"credentials",
"credential",
"token",
"access_token",
"refresh_token",
"private_key",
"privatekey",
"auth",
"authorization",
}
)
def _sanitize_agent_json(obj: Any) -> Any:
"""Recursively sanitize agent_json by removing sensitive keys.
Args:
obj: The object to sanitize (dict, list, or primitive)
Returns:
Sanitized copy with sensitive keys removed/redacted
"""
if isinstance(obj, dict):
return {
k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
for k, v in obj.items()
}
elif isinstance(obj, list):
return [_sanitize_agent_json(item) for item in obj]
else:
return obj
class ToolMessageUpdateError(Exception):
"""Raised when updating a tool message in the database fails."""
pass
async def _update_tool_message(
session_id: str,
tool_call_id: str,
content: str,
prisma_client: Prisma | None,
) -> None:
"""Update tool message in database.
Args:
session_id: The session ID
tool_call_id: The tool call ID to update
content: The new content for the message
prisma_client: Optional Prisma client. If None, uses chat_service.
Raises:
ToolMessageUpdateError: If the database update fails. The caller should
handle this to avoid marking the task as completed with inconsistent state.
"""
try:
if prisma_client:
# Use provided Prisma client (for consumer with its own connection)
updated_count = await prisma_client.chatmessage.update_many(
where={
"sessionId": session_id,
"toolCallId": tool_call_id,
},
data={"content": content},
)
# Check if any rows were updated - 0 means message not found
if updated_count == 0:
raise ToolMessageUpdateError(
f"No message found with tool_call_id={tool_call_id} in session {session_id}"
)
else:
# Use service function (for webhook endpoint)
await chat_service._update_pending_operation(
session_id=session_id,
tool_call_id=tool_call_id,
result=content,
)
except ToolMessageUpdateError:
raise
except Exception as e:
logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
raise ToolMessageUpdateError(
f"Failed to update tool message for tool_call_id={tool_call_id}: {e}"
) from e
def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
"""Serialize result to JSON string with sensible defaults.
Args:
result: The result to serialize. Can be a dict, list, string,
number, boolean, or None.
Returns:
JSON string representation of the result. Returns '{"status": "completed"}'
only when result is explicitly None.
"""
if isinstance(result, str):
return result
if result is None:
return '{"status": "completed"}'
return orjson.dumps(result).decode("utf-8")
async def _save_agent_from_result(
result: dict[str, Any],
user_id: str | None,
tool_name: str,
) -> dict[str, Any]:
"""Save agent to library if result contains agent_json.
Args:
result: The result dict that may contain agent_json
user_id: The user ID to save the agent for
tool_name: The tool name (create_agent or edit_agent)
Returns:
Updated result dict with saved agent details, or original result if no agent_json
"""
if not user_id:
logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
return result
agent_json = result.get("agent_json")
if not agent_json:
logger.warning(
f"[COMPLETION] {tool_name} completed but no agent_json in result"
)
return result
try:
from .tools.agent_generator import save_agent_to_library
is_update = tool_name == "edit_agent"
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id, is_update=is_update
)
logger.info(
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
)
# Return a response similar to AgentSavedResponse
return {
"type": "agent_saved",
"message": f"Agent '{created_graph.name}' has been saved to your library!",
"agent_id": created_graph.id,
"agent_name": created_graph.name,
"library_agent_id": library_agent.id,
"library_agent_link": f"/library/agents/{library_agent.id}",
"agent_page_link": f"/build?flowID={created_graph.id}",
}
except Exception as e:
logger.error(
f"[COMPLETION] Failed to save agent to library: {e}",
exc_info=True,
)
# Return error but don't fail the whole operation
# Sanitize agent_json to remove sensitive keys before returning
return {
"type": "error",
"message": f"Agent was generated but failed to save: {str(e)}",
"error": str(e),
"agent_json": _sanitize_agent_json(agent_json),
}
async def process_operation_success(
task: stream_registry.ActiveTask,
result: dict | str | None,
prisma_client: Prisma | None = None,
) -> None:
"""Handle successful operation completion.
Publishes the result to the stream registry, updates the database,
generates LLM continuation, and marks the task as completed.
Args:
task: The active task that completed
result: The result data from the operation
prisma_client: Optional Prisma client for database operations.
If None, uses chat_service._update_pending_operation instead.
Raises:
ToolMessageUpdateError: If the database update fails. The task will be
marked as failed instead of completed to avoid inconsistent state.
"""
# For agent generation tools, save the agent to library
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
# Serialize result for output (only substitute default when result is exactly None)
result_output = result if result is not None else {"status": "completed"}
output_str = (
result_output
if isinstance(result_output, str)
else orjson.dumps(result_output).decode("utf-8")
)
# Publish result to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
output=output_str,
success=True,
),
)
# Update pending operation in database
# If this fails, we must not continue to mark the task as completed
result_str = serialize_result(result)
try:
await _update_tool_message(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=result_str,
prisma_client=prisma_client,
)
except ToolMessageUpdateError:
# DB update failed - mark task as failed to avoid inconsistent state
logger.error(
f"[COMPLETION] DB update failed for task {task.task_id}, "
"marking as failed instead of completed"
)
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText="Failed to save operation result to database"),
)
await stream_registry.mark_task_completed(task.task_id, status="failed")
raise
# Generate LLM continuation with streaming
try:
await chat_service._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
except Exception as e:
logger.error(
f"[COMPLETION] Failed to generate LLM continuation: {e}",
exc_info=True,
)
# Mark task as completed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="completed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
)
async def process_operation_failure(
task: stream_registry.ActiveTask,
error: str | None,
prisma_client: Prisma | None = None,
) -> None:
"""Handle failed operation completion.
Publishes the error to the stream registry, updates the database with
the error response, and marks the task as failed.
Args:
task: The active task that failed
error: The error message from the operation
prisma_client: Optional Prisma client for database operations.
If None, uses chat_service._update_pending_operation instead.
"""
error_msg = error or "Operation failed"
# Publish error to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText=error_msg),
)
# Update pending operation with error
# If this fails, we still continue to mark the task as failed
error_response = ErrorResponse(
message=error_msg,
error=error,
)
try:
await _update_tool_message(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=error_response.model_dump_json(),
prisma_client=prisma_client,
)
except ToolMessageUpdateError:
# DB update failed - log but continue with cleanup
logger.error(
f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
"continuing with cleanup"
)
# Mark task as failed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="failed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")

config.py

@@ -44,48 +44,6 @@ class ChatConfig(BaseSettings):
         description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
     )
-    # Stream registry configuration for SSE reconnection
-    stream_ttl: int = Field(
-        default=3600,
-        description="TTL in seconds for stream data in Redis (1 hour)",
-    )
-    stream_max_length: int = Field(
-        default=10000,
-        description="Maximum number of messages to store per stream",
-    )
-    # Redis Streams configuration for completion consumer
-    stream_completion_name: str = Field(
-        default="chat:completions",
-        description="Redis Stream name for operation completions",
-    )
-    stream_consumer_group: str = Field(
-        default="chat_consumers",
-        description="Consumer group name for completion stream",
-    )
-    stream_claim_min_idle_ms: int = Field(
-        default=60000,
-        description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
-    )
-    # Redis key prefixes for stream registry
-    task_meta_prefix: str = Field(
-        default="chat:task:meta:",
-        description="Prefix for task metadata hash keys",
-    )
-    task_stream_prefix: str = Field(
-        default="chat:stream:",
-        description="Prefix for task message stream keys",
-    )
-    task_op_prefix: str = Field(
-        default="chat:task:op:",
-        description="Prefix for operation ID to task ID mapping keys",
-    )
-    internal_api_key: str | None = Field(
-        default=None,
-        description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
-    )
     # Langfuse Prompt Management Configuration
     # Note: Langfuse credentials are in Settings().secrets (settings.py)
     langfuse_prompt_name: str = Field(
@@ -124,14 +82,6 @@ class ChatConfig(BaseSettings):
             v = "https://openrouter.ai/api/v1"
         return v
-    @field_validator("internal_api_key", mode="before")
-    @classmethod
-    def get_internal_api_key(cls, v):
-        """Get internal API key from environment if not provided."""
-        if v is None:
-            v = os.getenv("CHAT_INTERNAL_API_KEY")
-        return v
     # Prompt paths for different contexts
     PROMPT_PATHS: dict[str, str] = {
         "default": "prompts/chat_system.md",

response_model.py

@@ -52,10 +52,6 @@ class StreamStart(StreamBaseResponse):
     type: ResponseType = ResponseType.START
     messageId: str = Field(..., description="Unique message ID")
-    taskId: str | None = Field(
-        default=None,
-        description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
-    )
 class StreamFinish(StreamBaseResponse):

routes.py

@@ -1,23 +1,19 @@
 """Chat API routes for chat session management and streaming via SSE."""
 import logging
-import uuid as uuid_module
 from collections.abc import AsyncGenerator
 from typing import Annotated
 from autogpt_libs import auth
-from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
+from fastapi import APIRouter, Depends, Query, Security
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from backend.util.exceptions import NotFoundError
 from . import service as chat_service
-from . import stream_registry
-from .completion_handler import process_operation_failure, process_operation_success
 from .config import ChatConfig
 from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
-from .response_model import StreamFinish, StreamHeartbeat, StreamStart
 config = ChatConfig()
@@ -59,15 +55,6 @@ class CreateSessionResponse(BaseModel):
     user_id: str | None
-class ActiveStreamInfo(BaseModel):
-    """Information about an active stream for reconnection."""
-    task_id: str
-    last_message_id: str  # Redis Stream message ID for resumption
-    operation_id: str  # Operation ID for completion tracking
-    tool_name: str  # Name of the tool being executed
 class SessionDetailResponse(BaseModel):
     """Response model providing complete details for a chat session, including messages."""
@@ -76,7 +63,6 @@ class SessionDetailResponse(BaseModel):
     updated_at: str
     user_id: str | None
     messages: list[dict]
-    active_stream: ActiveStreamInfo | None = None  # Present if stream is still active
 class SessionSummaryResponse(BaseModel):
@@ -95,14 +81,6 @@ class ListSessionsResponse(BaseModel):
     total: int
-class OperationCompleteRequest(BaseModel):
-    """Request model for external completion webhook."""
-    success: bool
-    result: dict | str | None = None
-    error: str | None = None
 # ========== Routes ==========
@@ -188,14 +166,13 @@ async def get_session(
     Retrieve the details of a specific chat session.
     Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
-    If there's an active stream for this session, returns the task_id for reconnection.
     Args:
         session_id: The unique identifier for the desired chat session.
         user_id: The optional authenticated user ID, or None for anonymous access.
     Returns:
-        SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
+        SessionDetailResponse: Details for the requested session, or None if not found.
     """
     session = await get_chat_session(session_id, user_id)
@@ -203,28 +180,11 @@ async def get_session(
         raise NotFoundError(f"Session {session_id} not found.")
     messages = [message.model_dump() for message in session.messages]
-    # Check if there's an active stream for this session
-    active_stream_info = None
-    active_task, last_message_id = await stream_registry.get_active_task_for_session(
-        session_id, user_id
-    )
-    if active_task:
-        # Filter out the in-progress assistant message from the session response.
-        # The client will receive the complete assistant response through the SSE
-        # stream replay instead, preventing duplicate content.
-        if messages and messages[-1].get("role") == "assistant":
-            messages = messages[:-1]
-        # Use "0-0" as last_message_id to replay the stream from the beginning.
-        # Since we filtered out the cached assistant message, the client needs
-        # the full stream to reconstruct the response.
-        active_stream_info = ActiveStreamInfo(
-            task_id=active_task.task_id,
-            last_message_id="0-0",
-            operation_id=active_task.operation_id,
-            tool_name=active_task.tool_name,
-        )
+    logger.info(
+        f"Returning session {session_id}: "
+        f"message_count={len(messages)}, "
+        f"roles={[m.get('role') for m in messages]}"
+    )
     return SessionDetailResponse(
         id=session.session_id,
@@ -232,7 +192,6 @@ async def get_session(
         updated_at=session.updated_at.isoformat(),
         user_id=session.user_id or None,
         messages=messages,
-        active_stream=active_stream_info,
     )
@@ -252,112 +211,49 @@ async def stream_chat_post(
     - Tool call UI elements (if invoked)
     - Tool execution results
-    The AI generation runs in a background task that continues even if the client disconnects.
-    All chunks are written to Redis for reconnection support. If the client disconnects,
-    they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
     Args:
         session_id: The chat session identifier to associate with the streamed messages.
         request: Request body containing message, is_user_message, and optional context.
         user_id: Optional authenticated user ID.
     Returns:
-        StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
-        containing the task_id for reconnection.
+        StreamingResponse: SSE-formatted response chunks.
     """
-    import asyncio
     session = await _validate_and_get_session(session_id, user_id)
-    # Create a task in the stream registry for reconnection support
-    task_id = str(uuid_module.uuid4())
-    operation_id = str(uuid_module.uuid4())
-    await stream_registry.create_task(
-        task_id=task_id,
-        session_id=session_id,
-        user_id=user_id,
-        tool_call_id="chat_stream",  # Not a tool call, but needed for the model
-        tool_name="chat",
-        operation_id=operation_id,
-    )
-    # Background task that runs the AI generation independently of SSE connection
-    async def run_ai_generation():
-        try:
-            # Emit a start event with task_id for reconnection
-            start_chunk = StreamStart(messageId=task_id, taskId=task_id)
-            await stream_registry.publish_chunk(task_id, start_chunk)
-            async for chunk in chat_service.stream_chat_completion(
-                session_id,
-                request.message,
-                is_user_message=request.is_user_message,
-                user_id=user_id,
-                session=session,  # Pass pre-fetched session to avoid double-fetch
-                context=request.context,
-            ):
-                # Write to Redis (subscribers will receive via XREAD)
-                await stream_registry.publish_chunk(task_id, chunk)
-            # Mark task as completed
-            await stream_registry.mark_task_completed(task_id, "completed")
-        except Exception as e:
-            logger.error(
-                f"Error in background AI generation for session {session_id}: {e}"
-            )
-            await stream_registry.mark_task_completed(task_id, "failed")
-    # Start the AI generation in a background task
-    bg_task = asyncio.create_task(run_ai_generation())
-    await stream_registry.set_task_asyncio_task(task_id, bg_task)
-    # SSE endpoint that subscribes to the task's stream
     async def event_generator() -> AsyncGenerator[str, None]:
-        subscriber_queue = None
-        try:
-            # Subscribe to the task stream (this replays existing messages + live updates)
-            subscriber_queue = await stream_registry.subscribe_to_task(
-                task_id=task_id,
-                user_id=user_id,
-                last_message_id="0-0",  # Get all messages from the beginning
-            )
-            if subscriber_queue is None:
-                yield StreamFinish().to_sse()
-                yield "data: [DONE]\n\n"
-                return
-            # Read from the subscriber queue and yield to SSE
-            while True:
-                try:
-                    chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
-                    yield chunk.to_sse()
-                    # Check for finish signal
-                    if isinstance(chunk, StreamFinish):
-                        break
-                except asyncio.TimeoutError:
-                    # Send heartbeat to keep connection alive
-                    yield StreamHeartbeat().to_sse()
-        except GeneratorExit:
-            pass  # Client disconnected - background task continues
-        except Exception as e:
-            logger.error(f"Error in SSE stream for task {task_id}: {e}")
-        finally:
-            # Unsubscribe when client disconnects or stream ends to prevent resource leak
-            if subscriber_queue is not None:
-                try:
-                    await stream_registry.unsubscribe_from_task(
-                        task_id, subscriber_queue
-                    )
-                except Exception as unsub_err:
-                    logger.error(
-                        f"Error unsubscribing from task {task_id}: {unsub_err}",
-                        exc_info=True,
-                    )
-            # AI SDK protocol termination - always yield even if unsubscribe fails
-            yield "data: [DONE]\n\n"
+        chunk_count = 0
+        first_chunk_type: str | None = None
+        async for chunk in chat_service.stream_chat_completion(
+            session_id,
+            request.message,
+            is_user_message=request.is_user_message,
+            user_id=user_id,
+            session=session,  # Pass pre-fetched session to avoid double-fetch
+            context=request.context,
+        ):
+            if chunk_count < 3:
+                logger.info(
+                    "Chat stream chunk",
+                    extra={
+                        "session_id": session_id,
+                        "chunk_type": str(chunk.type),
+                    },
+                )
+            if not first_chunk_type:
+                first_chunk_type = str(chunk.type)
+            chunk_count += 1
+            yield chunk.to_sse()
+        logger.info(
+            "Chat stream completed",
+            extra={
+                "session_id": session_id,
+                "chunk_count": chunk_count,
+                "first_chunk_type": first_chunk_type,
+            },
+        )
+        # AI SDK protocol termination
+        yield "data: [DONE]\n\n"
     return StreamingResponse(
         event_generator(),
@@ -470,251 +366,6 @@ async def session_assign_user(
return {"status": "ok"} return {"status": "ok"}
# ========== Task Streaming (SSE Reconnection) ==========
@router.get(
"/tasks/{task_id}/stream",
)
async def stream_task(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
last_message_id: str = Query(
default="0-0",
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
),
):
"""
Reconnect to a long-running task's SSE stream.
When a long-running operation (like agent generation) starts, the client
receives a task_id. If the connection drops, the client can reconnect
using this endpoint to resume receiving updates.
Args:
task_id: The task ID from the operation_started response.
user_id: Authenticated user ID for ownership validation.
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
Returns:
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
Raises:
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
"""
# Check task existence and expiry before subscribing
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
if error_code == "TASK_EXPIRED":
raise HTTPException(
status_code=410,
detail={
"code": "TASK_EXPIRED",
"message": "This operation has expired. Please try again.",
},
)
if error_code == "TASK_NOT_FOUND":
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found.",
},
)
# Validate ownership if task has an owner
if task and task.user_id and user_id != task.user_id:
raise HTTPException(
status_code=403,
detail={
"code": "ACCESS_DENIED",
"message": "You do not have access to this task.",
},
)
# Get subscriber queue from stream registry
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id=last_message_id,
)
if subscriber_queue is None:
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found or access denied.",
},
)
async def event_generator() -> AsyncGenerator[str, None]:
import asyncio
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try:
while True:
try:
# Wait for next chunk with timeout for heartbeats
chunk = await asyncio.wait_for(
subscriber_queue.get(), timeout=heartbeat_interval
)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except Exception as e:
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
finally:
# Unsubscribe when client disconnects or stream ends
try:
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {task_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
},
)
@router.get(
"/tasks/{task_id}",
)
async def get_task_status(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
) -> dict:
"""
Get the status of a long-running task.
Args:
task_id: The task ID to check.
user_id: Authenticated user ID for ownership validation.
Returns:
dict: Task status including task_id, status, tool_name, and operation_id.
Raises:
NotFoundError: If task_id is not found or user doesn't have access.
"""
task = await stream_registry.get_task(task_id)
if task is None:
raise NotFoundError(f"Task {task_id} not found.")
# Validate ownership - if task has an owner, requester must match
if task.user_id and user_id != task.user_id:
raise NotFoundError(f"Task {task_id} not found.")
return {
"task_id": task.task_id,
"session_id": task.session_id,
"status": task.status,
"tool_name": task.tool_name,
"operation_id": task.operation_id,
"created_at": task.created_at.isoformat(),
}
# ========== External Completion Webhook ==========
@router.post(
"/operations/{operation_id}/complete",
status_code=200,
)
async def complete_operation(
operation_id: str,
request: OperationCompleteRequest,
x_api_key: str | None = Header(default=None),
) -> dict:
"""
External completion webhook for long-running operations.
Called by Agent Generator (or other services) when an operation completes.
This triggers the stream registry to publish completion and continue LLM generation.
Args:
operation_id: The operation ID to complete.
request: Completion payload with success status and result/error.
x_api_key: Internal API key for authentication.
Returns:
dict: Status of the completion.
Raises:
HTTPException: If API key is invalid or operation not found.
"""
# Validate internal API key - reject if not configured or invalid
if not config.internal_api_key:
logger.error(
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
)
raise HTTPException(
status_code=503,
detail="Webhook not available: internal API key not configured",
)
if x_api_key != config.internal_api_key:
raise HTTPException(status_code=401, detail="Invalid API key")
# Find task by operation_id
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None:
raise HTTPException(
status_code=404,
detail=f"Operation {operation_id} not found",
)
logger.info(
f"Received completion webhook for operation {operation_id} "
f"(task_id={task.task_id}, success={request.success})"
)
if request.success:
await process_operation_success(task, request.result)
else:
await process_operation_failure(task, request.error)
return {"status": "ok", "task_id": task.task_id}
# ========== Configuration ==========
@router.get("/config/ttl", status_code=200)
async def get_ttl_config() -> dict:
"""
Get the stream TTL configuration.
Returns the Time-To-Live settings for chat streams, which determines
how long clients can reconnect to an active stream.
Returns:
dict: TTL configuration with seconds and milliseconds values.
"""
return {
"stream_ttl_seconds": config.stream_ttl,
"stream_ttl_ms": config.stream_ttl * 1000,
}
# ========== Health Check ==========

File diff suppressed because it is too large

stream_registry.py

@@ -1,704 +0,0 @@
"""Stream registry for managing reconnectable SSE streams.
This module provides a registry for tracking active streaming tasks and their
messages. It uses Redis for all state management (no in-memory state), making
pods stateless and horizontally scalable.
Architecture:
- Redis Stream: Persists all messages for replay and real-time delivery
- Redis Hash: Task metadata (status, session_id, etc.)
Subscribers:
1. Replay missed messages from Redis Stream (XREAD)
2. Listen for live updates via blocking XREAD
3. No in-memory state required on the subscribing pod
"""
import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal
import orjson
from backend.data.redis_client import get_redis_async
from .config import ChatConfig
from .response_model import StreamBaseResponse, StreamError, StreamFinish
logger = logging.getLogger(__name__)
config = ChatConfig()
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
_local_tasks: dict[str, asyncio.Task] = {}
# Track listener tasks per subscriber queue for cleanup
# Maps queue id() to (task_id, asyncio.Task) for proper cleanup on unsubscribe
_listener_tasks: dict[int, tuple[str, asyncio.Task]] = {}
# Timeout for putting chunks into subscriber queues (seconds)
# If the queue is full and doesn't drain within this time, send an overflow error
QUEUE_PUT_TIMEOUT = 5.0
# Lua script for atomic compare-and-swap status update (idempotent completion)
# Returns 1 if status was updated, 0 if already completed/failed
COMPLETE_TASK_SCRIPT = """
local current = redis.call("HGET", KEYS[1], "status")
if current == "running" then
redis.call("HSET", KEYS[1], "status", ARGV[1])
return 1
end
return 0
"""
@dataclass
class ActiveTask:
"""Represents an active streaming task (metadata only, no in-memory queues)."""
task_id: str
session_id: str
user_id: str | None
tool_call_id: str
tool_name: str
operation_id: str
status: Literal["running", "completed", "failed"] = "running"
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
asyncio_task: asyncio.Task | None = None
def _get_task_meta_key(task_id: str) -> str:
"""Get Redis key for task metadata."""
return f"{config.task_meta_prefix}{task_id}"
def _get_task_stream_key(task_id: str) -> str:
"""Get Redis key for task message stream."""
return f"{config.task_stream_prefix}{task_id}"
def _get_operation_mapping_key(operation_id: str) -> str:
"""Get Redis key for operation_id to task_id mapping."""
return f"{config.task_op_prefix}{operation_id}"
async def create_task(
task_id: str,
session_id: str,
user_id: str | None,
tool_call_id: str,
tool_name: str,
operation_id: str,
) -> ActiveTask:
"""Create a new streaming task in Redis.
Args:
task_id: Unique identifier for the task
session_id: Chat session ID
user_id: User ID (may be None for anonymous)
tool_call_id: Tool call ID from the LLM
tool_name: Name of the tool being executed
operation_id: Operation ID for webhook callbacks
Returns:
The created ActiveTask instance (metadata only)
"""
task = ActiveTask(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
operation_id=operation_id,
)
# Store metadata in Redis
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
op_key = _get_operation_mapping_key(operation_id)
await redis.hset( # type: ignore[misc]
meta_key,
mapping={
"task_id": task_id,
"session_id": session_id,
"user_id": user_id or "",
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"operation_id": operation_id,
"status": task.status,
"created_at": task.created_at.isoformat(),
},
)
await redis.expire(meta_key, config.stream_ttl)
# Create operation_id -> task_id mapping for webhook lookups
await redis.set(op_key, task_id, ex=config.stream_ttl)
logger.debug(f"Created task {task_id} for session {session_id}")
return task
async def publish_chunk(
task_id: str,
chunk: StreamBaseResponse,
) -> str:
"""Publish a chunk to Redis Stream.
All delivery is via Redis Streams - no in-memory state.
Args:
task_id: Task ID to publish to
chunk: The stream response chunk to publish
Returns:
The Redis Stream message ID
"""
chunk_json = chunk.model_dump_json()
message_id = "0-0"
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
# Write to Redis Stream for persistence and real-time delivery
raw_id = await redis.xadd(
stream_key,
{"data": chunk_json},
maxlen=config.stream_max_length,
)
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
# Set TTL on stream to match task metadata TTL
await redis.expire(stream_key, config.stream_ttl)
except Exception as e:
logger.error(
f"Failed to publish chunk for task {task_id}: {e}",
exc_info=True,
)
return message_id
async def subscribe_to_task(
task_id: str,
user_id: str | None,
last_message_id: str = "0-0",
) -> asyncio.Queue[StreamBaseResponse] | None:
"""Subscribe to a task's stream with replay of missed messages.
This is fully stateless - uses Redis Stream for replay and pub/sub for live updates.
Args:
task_id: Task ID to subscribe to
user_id: User ID for ownership validation
last_message_id: Last Redis Stream message ID received ("0-0" for full replay)
Returns:
An asyncio Queue that will receive stream chunks, or None if task not found
or user doesn't have access
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
logger.debug(f"Task {task_id} not found in Redis")
return None
# Note: Redis client uses decode_responses=True, so keys are strings
task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None
# Validate ownership - if task has an owner, requester must match
if task_user_id:
if user_id != task_user_id:
logger.warning(
f"User {user_id} denied access to task {task_id} "
f"owned by {task_user_id}"
)
return None
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
stream_key = _get_task_stream_key(task_id)
# Step 1: Replay messages from Redis Stream
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
replayed_count = 0
replay_last_id = last_message_id
if messages:
for _stream_name, stream_messages in messages:
for msg_id, msg_data in stream_messages:
replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
# Note: Redis client uses decode_responses=True, so keys are strings
if "data" in msg_data:
try:
chunk_data = orjson.loads(msg_data["data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
await subscriber_queue.put(chunk)
replayed_count += 1
except Exception as e:
logger.warning(f"Failed to replay message: {e}")
logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
# Step 2: If task is still running, start stream listener for live updates
if task_status == "running":
listener_task = asyncio.create_task(
_stream_listener(task_id, subscriber_queue, replay_last_id)
)
# Track listener task for cleanup on unsubscribe
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
else:
# Task is completed/failed - add finish marker
await subscriber_queue.put(StreamFinish())
return subscriber_queue
async def _stream_listener(
task_id: str,
subscriber_queue: asyncio.Queue[StreamBaseResponse],
last_replayed_id: str,
) -> None:
"""Listen to Redis Stream for new messages using blocking XREAD.
This approach avoids the duplicate message issue that can occur with pub/sub
when messages are published during the gap between replay and subscription.
Args:
task_id: Task ID to listen for
subscriber_queue: Queue to deliver messages to
last_replayed_id: Last message ID from replay (continue from here)
"""
queue_id = id(subscriber_queue)
# Track the last successfully delivered message ID for recovery hints
last_delivered_id = last_replayed_id
try:
redis = await get_redis_async()
stream_key = _get_task_stream_key(task_id)
current_id = last_replayed_id
while True:
# Block for up to 30 seconds waiting for new messages
# This allows periodic checking if task is still running
messages = await redis.xread(
{stream_key: current_id}, block=30000, count=100
)
if not messages:
# Timeout - check if task is still running
meta_key = _get_task_meta_key(task_id)
status = await redis.hget(meta_key, "status") # type: ignore[misc]
if status and status != "running":
try:
await asyncio.wait_for(
subscriber_queue.put(StreamFinish()),
timeout=QUEUE_PUT_TIMEOUT,
)
except asyncio.TimeoutError:
logger.warning(
f"Timeout delivering finish event for task {task_id}"
)
break
continue
for _stream_name, stream_messages in messages:
for msg_id, msg_data in stream_messages:
current_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
if "data" not in msg_data:
continue
try:
chunk_data = orjson.loads(msg_data["data"])
chunk = _reconstruct_chunk(chunk_data)
if chunk:
try:
await asyncio.wait_for(
subscriber_queue.put(chunk),
timeout=QUEUE_PUT_TIMEOUT,
)
# Update last delivered ID on successful delivery
last_delivered_id = current_id
except asyncio.TimeoutError:
logger.warning(
f"Subscriber queue full for task {task_id}, "
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
)
# Send overflow error with recovery info
try:
overflow_error = StreamError(
errorText="Message delivery timeout - some messages may have been missed",
code="QUEUE_OVERFLOW",
details={
"last_delivered_id": last_delivered_id,
"recovery_hint": f"Reconnect with last_message_id={last_delivered_id}",
},
)
subscriber_queue.put_nowait(overflow_error)
except asyncio.QueueFull:
# Queue is completely stuck, nothing more we can do
logger.error(
f"Cannot deliver overflow error for task {task_id}, "
"queue completely blocked"
)
# Stop listening on finish
if isinstance(chunk, StreamFinish):
return
except Exception as e:
logger.warning(f"Error processing stream message: {e}")
except asyncio.CancelledError:
logger.debug(f"Stream listener cancelled for task {task_id}")
raise # Re-raise to propagate cancellation
except Exception as e:
logger.error(f"Stream listener error for task {task_id}: {e}")
# On error, send finish to unblock subscriber
try:
await asyncio.wait_for(
subscriber_queue.put(StreamFinish()),
timeout=QUEUE_PUT_TIMEOUT,
)
except (asyncio.TimeoutError, asyncio.QueueFull):
logger.warning(
f"Could not deliver finish event for task {task_id} after error"
)
finally:
# Clean up listener task mapping on exit
_listener_tasks.pop(queue_id, None)
async def mark_task_completed(
task_id: str,
status: Literal["completed", "failed"] = "completed",
) -> bool:
"""Mark a task as completed and publish finish event.
This is idempotent - calling multiple times with the same task_id is safe.
Uses atomic compare-and-swap via Lua script to prevent race conditions.
Status is updated first (source of truth), then finish event is published (best-effort).
Args:
task_id: Task ID to mark as completed
status: Final status ("completed" or "failed")
Returns:
True if task was newly marked completed, False if already completed/failed
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
# Atomic compare-and-swap: only update if status is "running"
# This prevents race conditions when multiple callers try to complete simultaneously
result = await redis.eval(COMPLETE_TASK_SCRIPT, 1, meta_key, status) # type: ignore[misc]
if result == 0:
logger.debug(f"Task {task_id} already completed/failed, skipping")
return False
# THEN publish finish event (best-effort - listeners can detect via status polling)
try:
await publish_chunk(task_id, StreamFinish())
except Exception as e:
logger.error(
f"Failed to publish finish event for task {task_id}: {e}. "
"Listeners will detect completion via status polling."
)
# Clean up local task reference if exists
_local_tasks.pop(task_id, None)
return True
async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
"""Find a task by its operation ID.
Used by webhook callbacks to locate the task to update.
Args:
operation_id: Operation ID to search for
Returns:
ActiveTask if found, None otherwise
"""
redis = await get_redis_async()
op_key = _get_operation_mapping_key(operation_id)
task_id = await redis.get(op_key)
if not task_id:
return None
task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
return await get_task(task_id_str)
async def get_task(task_id: str) -> ActiveTask | None:
"""Get a task by its ID from Redis.
Args:
task_id: Task ID to look up
Returns:
ActiveTask if found, None otherwise
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
return None
# Note: Redis client uses decode_responses=True, so keys/values are strings
return ActiveTask(
task_id=meta.get("task_id", ""),
session_id=meta.get("session_id", ""),
user_id=meta.get("user_id", "") or None,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
operation_id=meta.get("operation_id", ""),
status=meta.get("status", "running"), # type: ignore[arg-type]
)
async def get_task_with_expiry_info(
task_id: str,
) -> tuple[ActiveTask | None, str | None]:
"""Get a task by its ID with expiration detection.
Returns (task, error_code) where error_code is:
- None if task found
- "TASK_EXPIRED" if stream exists but metadata is gone (TTL expired)
- "TASK_NOT_FOUND" if neither exists
Args:
task_id: Task ID to look up
Returns:
Tuple of (ActiveTask or None, error_code or None)
"""
redis = await get_redis_async()
meta_key = _get_task_meta_key(task_id)
stream_key = _get_task_stream_key(task_id)
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
if not meta:
# Check if stream still has data (metadata expired but stream hasn't)
stream_len = await redis.xlen(stream_key)
if stream_len > 0:
return None, "TASK_EXPIRED"
return None, "TASK_NOT_FOUND"
# Note: Redis client uses decode_responses=True, so keys/values are strings
return (
ActiveTask(
task_id=meta.get("task_id", ""),
session_id=meta.get("session_id", ""),
user_id=meta.get("user_id", "") or None,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
operation_id=meta.get("operation_id", ""),
status=meta.get("status", "running"), # type: ignore[arg-type]
),
None,
)
async def get_active_task_for_session(
session_id: str,
user_id: str | None = None,
) -> tuple[ActiveTask | None, str]:
"""Get the active (running) task for a session, if any.
Scans Redis for tasks matching the session_id with status="running".
Args:
session_id: Session ID to look up
user_id: User ID for ownership validation (optional)
Returns:
Tuple of (ActiveTask if found and running, last_message_id from Redis Stream)
"""
redis = await get_redis_async()
# Scan Redis for task metadata keys
cursor = 0
tasks_checked = 0
while True:
cursor, keys = await redis.scan(
cursor, match=f"{config.task_meta_prefix}*", count=100
)
for key in keys:
tasks_checked += 1
meta: dict[Any, Any] = await redis.hgetall(key) # type: ignore[misc]
if not meta:
continue
# Note: Redis client uses decode_responses=True, so keys/values are strings
task_session_id = meta.get("session_id", "")
task_status = meta.get("status", "")
task_user_id = meta.get("user_id", "") or None
task_id = meta.get("task_id", "")
if task_session_id == session_id and task_status == "running":
# Validate ownership - if task has an owner, requester must match
if task_user_id and user_id != task_user_id:
continue
# Get the last message ID from Redis Stream
stream_key = _get_task_stream_key(task_id)
last_id = "0-0"
try:
messages = await redis.xrevrange(stream_key, count=1)
if messages:
msg_id = messages[0][0]
last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
except Exception as e:
logger.warning(f"Failed to get last message ID: {e}")
return (
ActiveTask(
task_id=task_id,
session_id=task_session_id,
user_id=task_user_id,
tool_call_id=meta.get("tool_call_id", ""),
tool_name=meta.get("tool_name", ""),
operation_id=meta.get("operation_id", ""),
status="running",
),
last_id,
)
if cursor == 0:
break
return None, "0-0"
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
    """Reconstruct a StreamBaseResponse from JSON data.

    Args:
        chunk_data: Parsed JSON data from Redis

    Returns:
        Reconstructed response object, or None if unknown type
    """
    from .response_model import (
        ResponseType,
        StreamError,
        StreamFinish,
        StreamHeartbeat,
        StreamStart,
        StreamTextDelta,
        StreamTextEnd,
        StreamTextStart,
        StreamToolInputAvailable,
        StreamToolInputStart,
        StreamToolOutputAvailable,
        StreamUsage,
    )

    # Map response types to their corresponding classes
    type_to_class: dict[str, type[StreamBaseResponse]] = {
        ResponseType.START.value: StreamStart,
        ResponseType.FINISH.value: StreamFinish,
        ResponseType.TEXT_START.value: StreamTextStart,
        ResponseType.TEXT_DELTA.value: StreamTextDelta,
        ResponseType.TEXT_END.value: StreamTextEnd,
        ResponseType.TOOL_INPUT_START.value: StreamToolInputStart,
        ResponseType.TOOL_INPUT_AVAILABLE.value: StreamToolInputAvailable,
        ResponseType.TOOL_OUTPUT_AVAILABLE.value: StreamToolOutputAvailable,
        ResponseType.ERROR.value: StreamError,
        ResponseType.USAGE.value: StreamUsage,
        ResponseType.HEARTBEAT.value: StreamHeartbeat,
    }

    chunk_type = chunk_data.get("type")
    chunk_class = type_to_class.get(chunk_type)  # type: ignore[arg-type]
    if chunk_class is None:
        logger.warning(f"Unknown chunk type: {chunk_type}")
        return None

    try:
        return chunk_class(**chunk_data)
    except Exception as e:
        logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}")
        return None
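

# Illustrative sketch: replaying persisted chunks through _reconstruct_chunk.
# The "data" field name is an assumption about the stream entry schema.
#
#     for _msg_id, fields in await redis.xrange(stream_key):
#         chunk = _reconstruct_chunk(orjson.loads(fields["data"]))
#         if chunk is not None:
#             await subscriber_queue.put(chunk)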
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
    """Track the asyncio.Task for a task (local reference only).

    This is just for cleanup purposes - the task state is in Redis.

    Args:
        task_id: Task ID
        asyncio_task: The asyncio Task to track
    """
    _local_tasks[task_id] = asyncio_task


async def unsubscribe_from_task(
    task_id: str,
    subscriber_queue: asyncio.Queue[StreamBaseResponse],
) -> None:
    """Clean up when a subscriber disconnects.

    Cancels the XREAD-based listener task associated with this subscriber queue
    to prevent resource leaks.

    Args:
        task_id: Task ID
        subscriber_queue: The subscriber's queue used to look up the listener task
    """
    queue_id = id(subscriber_queue)
    listener_entry = _listener_tasks.pop(queue_id, None)
    if listener_entry is None:
        logger.debug(
            f"No listener task found for task {task_id} queue {queue_id} "
            "(may have already completed)"
        )
        return

    stored_task_id, listener_task = listener_entry
    if stored_task_id != task_id:
        logger.warning(
            f"Task ID mismatch in unsubscribe: expected {task_id}, "
            f"found {stored_task_id}"
        )

    if listener_task.done():
        logger.debug(f"Listener task for task {task_id} already completed")
        return

    # Cancel the listener task
    listener_task.cancel()
    try:
        # Wait for the task to be cancelled with a timeout
        await asyncio.wait_for(listener_task, timeout=5.0)
    except asyncio.CancelledError:
        # Expected - the task was successfully cancelled
        pass
    except asyncio.TimeoutError:
        logger.warning(
            f"Timeout waiting for listener task cancellation for task {task_id}"
        )
    except Exception as e:
        logger.error(f"Error during listener task cancellation for task {task_id}: {e}")

    logger.debug(f"Successfully unsubscribed from task {task_id}")

View File

@@ -10,7 +10,6 @@ from .add_understanding import AddUnderstandingTool
 from .agent_output import AgentOutputTool
 from .base import BaseTool
 from .create_agent import CreateAgentTool
-from .customize_agent import CustomizeAgentTool
 from .edit_agent import EditAgentTool
 from .find_agent import FindAgentTool
 from .find_block import FindBlockTool
@@ -35,7 +34,6 @@ logger = logging.getLogger(__name__)
 TOOL_REGISTRY: dict[str, BaseTool] = {
     "add_understanding": AddUnderstandingTool(),
     "create_agent": CreateAgentTool(),
-    "customize_agent": CustomizeAgentTool(),
     "edit_agent": EditAgentTool(),
     "find_agent": FindAgentTool(),
     "find_block": FindBlockTool(),

View File

@@ -8,7 +8,6 @@ from .core import (
     DecompositionStep,
     LibraryAgentSummary,
     MarketplaceAgentSummary,
-    customize_template,
     decompose_goal,
     enrich_library_agents_from_steps,
     extract_search_terms_from_steps,
@@ -20,7 +19,6 @@ from .core import (
     get_library_agent_by_graph_id,
     get_library_agent_by_id,
     get_library_agents_for_generation,
-    graph_to_json,
     json_to_graph,
     save_agent_to_library,
     search_marketplace_agents_for_generation,
@@ -38,7 +36,6 @@ __all__ = [
     "LibraryAgentSummary",
    "MarketplaceAgentSummary",
    "check_external_service_health",
-    "customize_template",
    "decompose_goal",
    "enrich_library_agents_from_steps",
    "extract_search_terms_from_steps",
@@ -51,7 +48,6 @@ __all__ = [
    "get_library_agent_by_id",
    "get_library_agents_for_generation",
    "get_user_message_for_error",
-    "graph_to_json",
    "is_external_service_configured",
    "json_to_graph",
    "save_agent_to_library",

View File

@@ -19,7 +19,6 @@ from backend.data.graph import (
 from backend.util.exceptions import DatabaseError, NotFoundError

 from .service import (
-    customize_template_external,
     decompose_goal_external,
     generate_agent_external,
     generate_agent_patch_external,
@@ -550,21 +549,15 @@ async def decompose_goal(
 async def generate_agent(
     instructions: DecompositionResult | dict[str, Any],
     library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
-    operation_id: str | None = None,
-    task_id: str | None = None,
 ) -> dict[str, Any] | None:
     """Generate agent JSON from instructions.

     Args:
         instructions: Structured instructions from decompose_goal
         library_agents: User's library agents available for sub-agent composition
-        operation_id: Operation ID for async processing (enables Redis Streams
-            completion notification)
-        task_id: Task ID for async processing (enables Redis Streams persistence
-            and SSE delivery)

     Returns:
-        Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
+        Agent JSON dict, error dict {"type": "error", ...}, or None on error

     Raises:
         AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -572,13 +565,8 @@ async def generate_agent(
     _check_service_configured()
     logger.info("Calling external Agent Generator service for generate_agent")
     result = await generate_agent_external(
-        dict(instructions), _to_dict_list(library_agents), operation_id, task_id
+        dict(instructions), _to_dict_list(library_agents)
     )
-
-    # Don't modify async response
-    if result and result.get("status") == "accepted":
-        return result
-
     if result:
         if isinstance(result, dict) and result.get("type") == "error":
             return result
@@ -752,15 +740,32 @@ async def save_agent_to_library(
     return created_graph, library_agents[0]


-def graph_to_json(graph: Graph) -> dict[str, Any]:
-    """Convert a Graph object to JSON format for the agent generator.
+async def get_agent_as_json(
+    agent_id: str, user_id: str | None
+) -> dict[str, Any] | None:
+    """Fetch an agent and convert to JSON format for editing.

     Args:
-        graph: Graph object to convert
+        agent_id: Graph ID or library agent ID
+        user_id: User ID

     Returns:
-        Agent as JSON dict
+        Agent as JSON dict or None if not found
     """
+    graph = await get_graph(agent_id, version=None, user_id=user_id)
+    if not graph and user_id:
+        try:
+            library_agent = await library_db.get_library_agent(agent_id, user_id)
+            graph = await get_graph(
+                library_agent.graph_id, version=None, user_id=user_id
+            )
+        except NotFoundError:
+            pass
+    if not graph:
+        return None
     nodes = []
     for node in graph.nodes:
         nodes.append(
@@ -797,41 +802,10 @@ def graph_to_json(graph: Graph) -> dict[str, Any]:
     }


-async def get_agent_as_json(
-    agent_id: str, user_id: str | None
-) -> dict[str, Any] | None:
-    """Fetch an agent and convert to JSON format for editing.
-
-    Args:
-        agent_id: Graph ID or library agent ID
-        user_id: User ID
-
-    Returns:
-        Agent as JSON dict or None if not found
-    """
-    graph = await get_graph(agent_id, version=None, user_id=user_id)
-    if not graph and user_id:
-        try:
-            library_agent = await library_db.get_library_agent(agent_id, user_id)
-            graph = await get_graph(
-                library_agent.graph_id, version=None, user_id=user_id
-            )
-        except NotFoundError:
-            pass
-    if not graph:
-        return None
-    return graph_to_json(graph)
-
-
 async def generate_agent_patch(
     update_request: str,
     current_agent: dict[str, Any],
     library_agents: list[AgentSummary] | None = None,
-    operation_id: str | None = None,
-    task_id: str | None = None,
 ) -> dict[str, Any] | None:
     """Update an existing agent using natural language.
@@ -844,12 +818,10 @@ async def generate_agent_patch(
         update_request: Natural language description of changes
         current_agent: Current agent JSON
         library_agents: User's library agents available for sub-agent composition
-        operation_id: Operation ID for async processing (enables Redis Streams callback)
-        task_id: Task ID for async processing (enables Redis Streams callback)

     Returns:
         Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
-        {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
+        error dict {"type": "error", ...}, or None on unexpected error

     Raises:
         AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -857,43 +829,5 @@ async def generate_agent_patch(
     _check_service_configured()
     logger.info("Calling external Agent Generator service for generate_agent_patch")
     return await generate_agent_patch_external(
-        update_request,
-        current_agent,
-        _to_dict_list(library_agents),
-        operation_id,
-        task_id,
+        update_request, current_agent, _to_dict_list(library_agents)
     )
-
-
-async def customize_template(
-    template_agent: dict[str, Any],
-    modification_request: str,
-    context: str = "",
-) -> dict[str, Any] | None:
-    """Customize a template/marketplace agent using natural language.
-
-    This is used when users want to modify a template or marketplace agent
-    to fit their specific needs before adding it to their library.
-
-    The external Agent Generator service handles:
-    - Understanding the modification request
-    - Applying changes to the template
-    - Fixing and validating the result
-
-    Args:
-        template_agent: The template agent JSON to customize
-        modification_request: Natural language description of customizations
-        context: Additional context (e.g., answers to previous questions)
-
-    Returns:
-        Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
-        error dict {"type": "error", ...}, or None on unexpected error
-
-    Raises:
-        AgentGeneratorNotConfiguredError: If the external service is not configured.
-    """
-    _check_service_configured()
-    logger.info("Calling external Agent Generator service for customize_template")
-    return await customize_template_external(
-        template_agent, modification_request, context
-    )

View File

@@ -139,10 +139,11 @@ async def decompose_goal_external(
     """
     client = _get_client()

-    if context:
-        description = f"{description}\n\nAdditional context from user:\n{context}"
-
+    # Build the request payload
     payload: dict[str, Any] = {"description": description}
+    if context:
+        # The external service uses user_instruction for additional context
+        payload["user_instruction"] = context
     if library_agents:
         payload["library_agents"] = library_agents
@@ -212,45 +213,24 @@ async def decompose_goal_external(
 async def generate_agent_external(
     instructions: dict[str, Any],
     library_agents: list[dict[str, Any]] | None = None,
-    operation_id: str | None = None,
-    task_id: str | None = None,
 ) -> dict[str, Any] | None:
     """Call the external service to generate an agent from instructions.

     Args:
         instructions: Structured instructions from decompose_goal
         library_agents: User's library agents available for sub-agent composition
-        operation_id: Operation ID for async processing (enables Redis Streams callback)
-        task_id: Task ID for async processing (enables Redis Streams callback)

     Returns:
-        Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
+        Agent JSON dict on success, or error dict {"type": "error", ...} on error
     """
     client = _get_client()

-    # Build request payload
     payload: dict[str, Any] = {"instructions": instructions}
     if library_agents:
         payload["library_agents"] = library_agents
-    if operation_id and task_id:
-        payload["operation_id"] = operation_id
-        payload["task_id"] = task_id

     try:
         response = await client.post("/api/generate-agent", json=payload)
-
-        # Handle 202 Accepted for async processing
-        if response.status_code == 202:
-            logger.info(
-                f"Agent Generator accepted async request "
-                f"(operation_id={operation_id}, task_id={task_id})"
-            )
-            return {
-                "status": "accepted",
-                "operation_id": operation_id,
-                "task_id": task_id,
-            }
-
         response.raise_for_status()
         data = response.json()
@@ -282,8 +262,6 @@ async def generate_agent_patch_external(
     update_request: str,
     current_agent: dict[str, Any],
     library_agents: list[dict[str, Any]] | None = None,
-    operation_id: str | None = None,
-    task_id: str | None = None,
 ) -> dict[str, Any] | None:
     """Call the external service to generate a patch for an existing agent.
@@ -291,40 +269,21 @@ async def generate_agent_patch_external(
         update_request: Natural language description of changes
         current_agent: Current agent JSON
         library_agents: User's library agents available for sub-agent composition
-        operation_id: Operation ID for async processing (enables Redis Streams callback)
-        task_id: Task ID for async processing (enables Redis Streams callback)

     Returns:
-        Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
+        Updated agent JSON, clarifying questions dict, or error dict on error
     """
     client = _get_client()

-    # Build request payload
     payload: dict[str, Any] = {
         "update_request": update_request,
         "current_agent_json": current_agent,
     }
     if library_agents:
         payload["library_agents"] = library_agents
-    if operation_id and task_id:
-        payload["operation_id"] = operation_id
-        payload["task_id"] = task_id

     try:
         response = await client.post("/api/update-agent", json=payload)
-
-        # Handle 202 Accepted for async processing
-        if response.status_code == 202:
-            logger.info(
-                f"Agent Generator accepted async update request "
-                f"(operation_id={operation_id}, task_id={task_id})"
-            )
-            return {
-                "status": "accepted",
-                "operation_id": operation_id,
-                "task_id": task_id,
-            }
-
         response.raise_for_status()
         data = response.json()
@@ -368,77 +327,6 @@ async def generate_agent_patch_external(
     return _create_error_response(error_msg, "unexpected_error")


-async def customize_template_external(
-    template_agent: dict[str, Any],
-    modification_request: str,
-    context: str = "",
-) -> dict[str, Any] | None:
-    """Call the external service to customize a template/marketplace agent.
-
-    Args:
-        template_agent: The template agent JSON to customize
-        modification_request: Natural language description of customizations
-        context: Additional context (e.g., answers to previous questions)
-
-    Returns:
-        Customized agent JSON, clarifying questions dict, or error dict on error
-    """
-    client = _get_client()
-
-    request = modification_request
-    if context:
-        request = f"{modification_request}\n\nAdditional context from user:\n{context}"
-
-    payload: dict[str, Any] = {
-        "template_agent_json": template_agent,
-        "modification_request": request,
-    }
-
-    try:
-        response = await client.post("/api/template-modification", json=payload)
-        response.raise_for_status()
-        data = response.json()
-
-        if not data.get("success"):
-            error_msg = data.get("error", "Unknown error from Agent Generator")
-            error_type = data.get("error_type", "unknown")
-            logger.error(
-                f"Agent Generator template customization failed: {error_msg} "
-                f"(type: {error_type})"
-            )
-            return _create_error_response(error_msg, error_type)
-
-        # Check if it's clarifying questions
-        if data.get("type") == "clarifying_questions":
-            return {
-                "type": "clarifying_questions",
-                "questions": data.get("questions", []),
-            }
-
-        # Check if it's an error passed through
-        if data.get("type") == "error":
-            return _create_error_response(
-                data.get("error", "Unknown error"),
-                data.get("error_type", "unknown"),
-            )
-
-        # Otherwise return the customized agent JSON
-        return data.get("agent_json")
-
-    except httpx.HTTPStatusError as e:
-        error_type, error_msg = _classify_http_error(e)
-        logger.error(error_msg)
-        return _create_error_response(error_msg, error_type)
-    except httpx.RequestError as e:
-        error_type, error_msg = _classify_request_error(e)
-        logger.error(error_msg)
-        return _create_error_response(error_msg, error_type)
-    except Exception as e:
-        error_msg = f"Unexpected error calling Agent Generator: {e}"
-        logger.error(error_msg)
-        return _create_error_response(error_msg, "unexpected_error")
-
-
 async def get_blocks_external() -> list[dict[str, Any]] | None:
     """Get available blocks from the external service.

View File

@@ -18,7 +18,6 @@ from .base import BaseTool
 from .models import (
     AgentPreviewResponse,
     AgentSavedResponse,
-    AsyncProcessingResponse,
     ClarificationNeededResponse,
     ClarifyingQuestion,
     ErrorResponse,
@@ -99,10 +98,6 @@ class CreateAgentTool(BaseTool):
         save = kwargs.get("save", True)
         session_id = session.session_id if session else None

-        # Extract async processing params (passed by long-running tool handler)
-        operation_id = kwargs.get("_operation_id")
-        task_id = kwargs.get("_task_id")
-
         if not description:
             return ErrorResponse(
                 message="Please provide a description of what the agent should do.",
@@ -224,12 +219,7 @@ class CreateAgentTool(BaseTool):
             logger.warning(f"Failed to enrich library agents from steps: {e}")

         try:
-            agent_json = await generate_agent(
-                decomposition_result,
-                library_agents,
-                operation_id=operation_id,
-                task_id=task_id,
-            )
+            agent_json = await generate_agent(decomposition_result, library_agents)
         except AgentGeneratorNotConfiguredError:
             return ErrorResponse(
                 message=(
@@ -273,19 +263,6 @@ class CreateAgentTool(BaseTool):
                 session_id=session_id,
             )

-        # Check if Agent Generator accepted for async processing
-        if agent_json.get("status") == "accepted":
-            logger.info(
-                f"Agent generation delegated to async processing "
-                f"(operation_id={operation_id}, task_id={task_id})"
-            )
-            return AsyncProcessingResponse(
-                message="Agent generation started. You'll be notified when it's complete.",
-                operation_id=operation_id,
-                task_id=task_id,
-                session_id=session_id,
-            )
-
         agent_name = agent_json.get("name", "Generated Agent")
         agent_description = agent_json.get("description", "")
         node_count = len(agent_json.get("nodes", []))

View File

@@ -1,337 +0,0 @@
-"""CustomizeAgentTool - Customizes marketplace/template agents using natural language."""
-
-import logging
-from typing import Any
-
-from backend.api.features.chat.model import ChatSession
-from backend.api.features.store import db as store_db
-from backend.api.features.store.exceptions import AgentNotFoundError
-
-from .agent_generator import (
-    AgentGeneratorNotConfiguredError,
-    customize_template,
-    get_user_message_for_error,
-    graph_to_json,
-    save_agent_to_library,
-)
-from .base import BaseTool
-from .models import (
-    AgentPreviewResponse,
-    AgentSavedResponse,
-    ClarificationNeededResponse,
-    ClarifyingQuestion,
-    ErrorResponse,
-    ToolResponseBase,
-)
-
-logger = logging.getLogger(__name__)
-
-
-class CustomizeAgentTool(BaseTool):
-    """Tool for customizing marketplace/template agents using natural language."""
-
-    @property
-    def name(self) -> str:
-        return "customize_agent"
-
-    @property
-    def description(self) -> str:
-        return (
-            "Customize a marketplace or template agent using natural language. "
-            "Takes an existing agent from the marketplace and modifies it based on "
-            "the user's requirements before adding to their library."
-        )
-
-    @property
-    def requires_auth(self) -> bool:
-        return True
-
-    @property
-    def is_long_running(self) -> bool:
-        return True
-
-    @property
-    def parameters(self) -> dict[str, Any]:
-        return {
-            "type": "object",
-            "properties": {
-                "agent_id": {
-                    "type": "string",
-                    "description": (
-                        "The marketplace agent ID in format 'creator/slug' "
-                        "(e.g., 'autogpt/newsletter-writer'). "
-                        "Get this from find_agent results."
-                    ),
-                },
-                "modifications": {
-                    "type": "string",
-                    "description": (
-                        "Natural language description of how to customize the agent. "
-                        "Be specific about what changes you want to make."
-                    ),
-                },
-                "context": {
-                    "type": "string",
-                    "description": (
-                        "Additional context or answers to previous clarifying questions."
-                    ),
-                },
-                "save": {
-                    "type": "boolean",
-                    "description": (
-                        "Whether to save the customized agent to the user's library. "
-                        "Default is true. Set to false for preview only."
-                    ),
-                    "default": True,
-                },
-            },
-            "required": ["agent_id", "modifications"],
-        }
-
-    async def _execute(
-        self,
-        user_id: str | None,
-        session: ChatSession,
-        **kwargs,
-    ) -> ToolResponseBase:
-        """Execute the customize_agent tool.
-
-        Flow:
-        1. Parse the agent ID to get creator/slug
-        2. Fetch the template agent from the marketplace
-        3. Call customize_template with the modification request
-        4. Preview or save based on the save parameter
-        """
-        agent_id = kwargs.get("agent_id", "").strip()
-        modifications = kwargs.get("modifications", "").strip()
-        context = kwargs.get("context", "")
-        save = kwargs.get("save", True)
-        session_id = session.session_id if session else None
-
-        if not agent_id:
-            return ErrorResponse(
-                message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
-                error="missing_agent_id",
-                session_id=session_id,
-            )
-
-        if not modifications:
-            return ErrorResponse(
-                message="Please describe how you want to customize this agent.",
-                error="missing_modifications",
-                session_id=session_id,
-            )
-
-        # Parse agent_id in format "creator/slug"
-        parts = [p.strip() for p in agent_id.split("/")]
-        if len(parts) != 2 or not parts[0] or not parts[1]:
-            return ErrorResponse(
-                message=(
-                    f"Invalid agent ID format: '{agent_id}'. "
-                    "Expected format is 'creator/agent-name' "
-                    "(e.g., 'autogpt/newsletter-writer')."
-                ),
-                error="invalid_agent_id_format",
-                session_id=session_id,
-            )
-        creator_username, agent_slug = parts
-
-        # Fetch the marketplace agent details
-        try:
-            agent_details = await store_db.get_store_agent_details(
-                username=creator_username, agent_name=agent_slug
-            )
-        except AgentNotFoundError:
-            return ErrorResponse(
-                message=(
-                    f"Could not find marketplace agent '{agent_id}'. "
-                    "Please check the agent ID and try again."
-                ),
-                error="agent_not_found",
-                session_id=session_id,
-            )
-        except Exception as e:
-            logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
-            return ErrorResponse(
-                message="Failed to fetch the marketplace agent. Please try again.",
-                error="fetch_error",
-                session_id=session_id,
-            )
-
-        if not agent_details.store_listing_version_id:
-            return ErrorResponse(
-                message=(
-                    f"The agent '{agent_id}' does not have an available version. "
-                    "Please try a different agent."
-                ),
-                error="no_version_available",
-                session_id=session_id,
-            )
-
-        # Get the full agent graph
-        try:
-            graph = await store_db.get_agent(agent_details.store_listing_version_id)
-            template_agent = graph_to_json(graph)
-        except Exception as e:
-            logger.error(f"Error fetching agent graph for {agent_id}: {e}")
-            return ErrorResponse(
-                message="Failed to fetch the agent configuration. Please try again.",
-                error="graph_fetch_error",
-                session_id=session_id,
-            )
-
-        # Call customize_template
-        try:
-            result = await customize_template(
-                template_agent=template_agent,
-                modification_request=modifications,
-                context=context,
-            )
-        except AgentGeneratorNotConfiguredError:
-            return ErrorResponse(
-                message=(
-                    "Agent customization is not available. "
-                    "The Agent Generator service is not configured."
-                ),
-                error="service_not_configured",
-                session_id=session_id,
-            )
-        except Exception as e:
-            logger.error(f"Error calling customize_template for {agent_id}: {e}")
-            return ErrorResponse(
-                message=(
-                    "Failed to customize the agent due to a service error. "
-                    "Please try again."
-                ),
-                error="customization_service_error",
-                session_id=session_id,
-            )
-
-        if result is None:
-            return ErrorResponse(
-                message=(
-                    "Failed to customize the agent. "
-                    "The agent generation service may be unavailable or timed out. "
-                    "Please try again."
-                ),
-                error="customization_failed",
-                session_id=session_id,
-            )
-
-        # Handle error response
-        if isinstance(result, dict) and result.get("type") == "error":
-            error_msg = result.get("error", "Unknown error")
-            error_type = result.get("error_type", "unknown")
-            user_message = get_user_message_for_error(
-                error_type,
-                operation="customize the agent",
-                llm_parse_message=(
-                    "The AI had trouble customizing the agent. "
-                    "Please try again or simplify your request."
-                ),
-                validation_message=(
-                    "The customized agent failed validation. "
-                    "Please try rephrasing your request."
-                ),
-                error_details=error_msg,
-            )
-            return ErrorResponse(
-                message=user_message,
-                error=f"customization_failed:{error_type}",
-                session_id=session_id,
-            )
-
-        # Handle clarifying questions
-        if isinstance(result, dict) and result.get("type") == "clarifying_questions":
-            questions = result.get("questions") or []
-            if not isinstance(questions, list):
-                logger.error(
-                    f"Unexpected clarifying questions format: {type(questions)}"
-                )
-                questions = []
-            return ClarificationNeededResponse(
-                message=(
-                    "I need some more information to customize this agent. "
-                    "Please answer the following questions:"
-                ),
-                questions=[
-                    ClarifyingQuestion(
-                        question=q.get("question", ""),
-                        keyword=q.get("keyword", ""),
-                        example=q.get("example"),
-                    )
-                    for q in questions
-                    if isinstance(q, dict)
-                ],
-                session_id=session_id,
-            )
-
-        # Result should be the customized agent JSON
-        if not isinstance(result, dict):
-            logger.error(f"Unexpected customize_template response type: {type(result)}")
-            return ErrorResponse(
-                message="Failed to customize the agent due to an unexpected response.",
-                error="unexpected_response_type",
-                session_id=session_id,
-            )
-
-        customized_agent = result
-        agent_name = customized_agent.get(
-            "name", f"Customized {agent_details.agent_name}"
-        )
-        agent_description = customized_agent.get("description", "")
-        nodes = customized_agent.get("nodes")
-        links = customized_agent.get("links")
-        node_count = len(nodes) if isinstance(nodes, list) else 0
-        link_count = len(links) if isinstance(links, list) else 0
-
-        if not save:
-            return AgentPreviewResponse(
-                message=(
-                    f"I've customized the agent '{agent_details.agent_name}'. "
-                    f"The customized agent has {node_count} blocks. "
-                    f"Review it and call customize_agent with save=true to save it."
-                ),
-                agent_json=customized_agent,
-                agent_name=agent_name,
-                description=agent_description,
-                node_count=node_count,
-                link_count=link_count,
-                session_id=session_id,
-            )
-
-        if not user_id:
-            return ErrorResponse(
-                message="You must be logged in to save agents.",
-                error="auth_required",
-                session_id=session_id,
-            )
-
-        # Save to user's library
-        try:
-            created_graph, library_agent = await save_agent_to_library(
-                customized_agent, user_id, is_update=False
-            )
-            return AgentSavedResponse(
-                message=(
-                    f"Customized agent '{created_graph.name}' "
-                    f"(based on '{agent_details.agent_name}') "
-                    f"has been saved to your library!"
-                ),
-                agent_id=created_graph.id,
-                agent_name=created_graph.name,
-                library_agent_id=library_agent.id,
-                library_agent_link=f"/library/agents/{library_agent.id}",
-                agent_page_link=f"/build?flowID={created_graph.id}",
-                session_id=session_id,
-            )
-        except Exception as e:
-            logger.error(f"Error saving customized agent: {e}")
-            return ErrorResponse(
-                message="Failed to save the customized agent. Please try again.",
-                error="save_failed",
-                session_id=session_id,
-            )

View File

@@ -17,7 +17,6 @@ from .base import BaseTool
 from .models import (
     AgentPreviewResponse,
     AgentSavedResponse,
-    AsyncProcessingResponse,
     ClarificationNeededResponse,
     ClarifyingQuestion,
     ErrorResponse,
@@ -105,10 +104,6 @@ class EditAgentTool(BaseTool):
         save = kwargs.get("save", True)
         session_id = session.session_id if session else None

-        # Extract async processing params (passed by long-running tool handler)
-        operation_id = kwargs.get("_operation_id")
-        task_id = kwargs.get("_task_id")
-
         if not agent_id:
             return ErrorResponse(
                 message="Please provide the agent ID to edit.",
@@ -154,11 +149,7 @@ class EditAgentTool(BaseTool):
         try:
             result = await generate_agent_patch(
-                update_request,
-                current_agent,
-                library_agents,
-                operation_id=operation_id,
-                task_id=task_id,
+                update_request, current_agent, library_agents
             )
         except AgentGeneratorNotConfiguredError:
             return ErrorResponse(
@@ -178,20 +169,6 @@ class EditAgentTool(BaseTool):
                 session_id=session_id,
             )

-        # Check if Agent Generator accepted for async processing
-        if result.get("status") == "accepted":
-            logger.info(
-                f"Agent edit delegated to async processing "
-                f"(operation_id={operation_id}, task_id={task_id})"
-            )
-            return AsyncProcessingResponse(
-                message="Agent edit started. You'll be notified when it's complete.",
-                operation_id=operation_id,
-                task_id=task_id,
-                session_id=session_id,
-            )
-
-        # Check if the result is an error from the external service
         if isinstance(result, dict) and result.get("type") == "error":
             error_msg = result.get("error", "Unknown error")
             error_type = result.get("error_type", "unknown")

View File

@@ -38,8 +38,6 @@ class ResponseType(str, Enum):
     OPERATION_STARTED = "operation_started"
     OPERATION_PENDING = "operation_pending"
     OPERATION_IN_PROGRESS = "operation_in_progress"
-    # Input validation
-    INPUT_VALIDATION_ERROR = "input_validation_error"


 # Base response model
@@ -70,10 +68,6 @@ class AgentInfo(BaseModel):
     has_external_trigger: bool | None = None
     new_output: bool | None = None
     graph_id: str | None = None
-    inputs: dict[str, Any] | None = Field(
-        default=None,
-        description="Input schema for the agent, including field names, types, and defaults",
-    )


 class AgentsFoundResponse(ToolResponseBase):
@@ -200,20 +194,6 @@ class ErrorResponse(ToolResponseBase):
     details: dict[str, Any] | None = None


-class InputValidationErrorResponse(ToolResponseBase):
-    """Response when run_agent receives unknown input fields."""
-
-    type: ResponseType = ResponseType.INPUT_VALIDATION_ERROR
-    unrecognized_fields: list[str] = Field(
-        description="List of input field names that were not recognized"
-    )
-    inputs: dict[str, Any] = Field(
-        description="The agent's valid input schema for reference"
-    )
-    graph_id: str | None = None
-    graph_version: int | None = None
-
-
 # Agent output models
 class ExecutionOutputInfo(BaseModel):
     """Summary of a single execution's outputs."""
@@ -372,15 +352,11 @@ class OperationStartedResponse(ToolResponseBase):
     This is returned immediately to the client while the operation continues
     to execute. The user can close the tab and check back later.
-
-    The task_id can be used to reconnect to the SSE stream via
-    GET /chat/tasks/{task_id}/stream?last_idx=0
     """

     type: ResponseType = ResponseType.OPERATION_STARTED
     operation_id: str
     tool_name: str
-    task_id: str | None = None  # For SSE reconnection


 class OperationPendingResponse(ToolResponseBase):
@@ -404,20 +380,3 @@ class OperationInProgressResponse(ToolResponseBase):
     type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
     tool_call_id: str
-
-
-class AsyncProcessingResponse(ToolResponseBase):
-    """Response when an operation has been delegated to async processing.
-
-    This is returned by tools when the external service accepts the request
-    for async processing (HTTP 202 Accepted). The Redis Streams completion
-    consumer will handle the result when the external service completes.
-
-    The status field is specifically "accepted" to allow the long-running tool
-    handler to detect this response and skip LLM continuation.
-    """
-
-    type: ResponseType = ResponseType.OPERATION_STARTED
-    status: str = "accepted"  # Must be "accepted" for detection
-    operation_id: str | None = None
-    task_id: str | None = None

View File

@@ -30,7 +30,6 @@ from .models import (
     ErrorResponse,
     ExecutionOptions,
     ExecutionStartedResponse,
-    InputValidationErrorResponse,
     SetupInfo,
     SetupRequirementsResponse,
     ToolResponseBase,
@@ -274,22 +273,6 @@ class RunAgentTool(BaseTool):
         input_properties = graph.input_schema.get("properties", {})
         required_fields = set(graph.input_schema.get("required", []))
         provided_inputs = set(params.inputs.keys())
-        valid_fields = set(input_properties.keys())
-
-        # Check for unknown input fields
-        unrecognized_fields = provided_inputs - valid_fields
-        if unrecognized_fields:
-            return InputValidationErrorResponse(
-                message=(
-                    f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
-                    f"Agent was not executed. Please use the correct field names from the schema."
-                ),
-                session_id=session_id,
-                unrecognized_fields=sorted(unrecognized_fields),
-                inputs=graph.input_schema,
-                graph_id=graph.id,
-                graph_version=graph.version,
-            )

         # If agent has inputs but none were provided AND use_defaults is not set,
         # always show what's available first so user can decide

View File

@@ -402,42 +402,3 @@ async def test_run_agent_schedule_without_name(setup_test_data):
     # Should return error about missing schedule_name
     assert result_data.get("type") == "error"
     assert "schedule_name" in result_data["message"].lower()
-
-
-@pytest.mark.asyncio(loop_scope="session")
-async def test_run_agent_rejects_unknown_input_fields(setup_test_data):
-    """Test that run_agent returns input_validation_error for unknown input fields."""
-    user = setup_test_data["user"]
-    store_submission = setup_test_data["store_submission"]
-    tool = RunAgentTool()
-    agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
-    session = make_session(user_id=user.id)
-
-    # Execute with unknown input field names
-    response = await tool.execute(
-        user_id=user.id,
-        session_id=str(uuid.uuid4()),
-        tool_call_id=str(uuid.uuid4()),
-        username_agent_slug=agent_marketplace_id,
-        inputs={
-            "unknown_field": "some value",
-            "another_unknown": "another value",
-        },
-        session=session,
-    )
-
-    assert response is not None
-    assert hasattr(response, "output")
-    assert isinstance(response.output, str)
-    result_data = orjson.loads(response.output)
-
-    # Should return input_validation_error type with unrecognized fields
-    assert result_data.get("type") == "input_validation_error"
-    assert "unrecognized_fields" in result_data
-    assert set(result_data["unrecognized_fields"]) == {
-        "another_unknown",
-        "unknown_field",
-    }
-    assert "inputs" in result_data  # Contains the valid schema
-    assert "Agent was not executed" in result_data["message"]

View File

@@ -5,8 +5,6 @@ import uuid
 from collections import defaultdict
 from typing import Any

-from pydantic_core import PydanticUndefined
-
 from backend.api.features.chat.model import ChatSession
 from backend.data.block import get_block
 from backend.data.execution import ExecutionContext
@@ -77,22 +75,15 @@ class RunBlockTool(BaseTool):
         self,
         user_id: str,
         block: Any,
-        input_data: dict[str, Any] | None = None,
     ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
         """
         Check if user has required credentials for a block.

-        Args:
-            user_id: User ID
-            block: Block to check credentials for
-            input_data: Input data for the block (used to determine provider via discriminator)
-
         Returns:
             tuple[matched_credentials, missing_credentials]
         """
         matched_credentials: dict[str, CredentialsMetaInput] = {}
         missing_credentials: list[CredentialsMetaInput] = []
-        input_data = input_data or {}

         # Get credential field info from block's input schema
         credentials_fields_info = block.input_schema.get_credentials_fields_info()
@@ -105,33 +96,14 @@ class RunBlockTool(BaseTool):
         available_creds = await creds_manager.store.get_all_creds(user_id)

         for field_name, field_info in credentials_fields_info.items():
-            effective_field_info = field_info
-            if field_info.discriminator and field_info.discriminator_mapping:
-                # Get discriminator from input, falling back to schema default
-                discriminator_value = input_data.get(field_info.discriminator)
-                if discriminator_value is None:
-                    field = block.input_schema.model_fields.get(
-                        field_info.discriminator
-                    )
-                    if field and field.default is not PydanticUndefined:
-                        discriminator_value = field.default
-                if (
-                    discriminator_value
-                    and discriminator_value in field_info.discriminator_mapping
-                ):
-                    effective_field_info = field_info.discriminate(discriminator_value)
-                    logger.debug(
-                        f"Discriminated provider for {field_name}: "
-                        f"{discriminator_value} -> {effective_field_info.provider}"
-                    )
-
+            # field_info.provider is a frozenset of acceptable providers
+            # field_info.supported_types is a frozenset of acceptable types
             matching_cred = next(
                 (
                     cred
                     for cred in available_creds
-                    if cred.provider in effective_field_info.provider
-                    and cred.type in effective_field_info.supported_types
+                    if cred.provider in field_info.provider
+                    and cred.type in field_info.supported_types
                 ),
                 None,
             )
@@ -145,8 +117,8 @@ class RunBlockTool(BaseTool):
             )
         else:
             # Create a placeholder for the missing credential
-            provider = next(iter(effective_field_info.provider), "unknown")
-            cred_type = next(iter(effective_field_info.supported_types), "api_key")
+            provider = next(iter(field_info.provider), "unknown")
+            cred_type = next(iter(field_info.supported_types), "api_key")
             missing_credentials.append(
                 CredentialsMetaInput(
                     id=field_name,
@@ -214,9 +186,10 @@ class RunBlockTool(BaseTool):
         logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")

+        # Check credentials
         creds_manager = IntegrationCredentialsManager()
         matched_credentials, missing_credentials = await self._check_block_credentials(
-            user_id, block, input_data
+            user_id, block
         )

         if missing_credentials:

View File

@@ -454,9 +454,6 @@ async def test_unified_hybrid_search_pagination(
     cleanup_embeddings: list,
 ):
     """Test unified search pagination works correctly."""
-    # Use a unique search term to avoid matching other test data
-    unique_term = f"xyzpagtest{uuid.uuid4().hex[:8]}"
-
     # Create multiple items
     content_ids = []
     for i in range(5):
@@ -468,14 +465,14 @@ async def test_unified_hybrid_search_pagination(
             content_type=ContentType.BLOCK,
             content_id=content_id,
             embedding=mock_embedding,
-            searchable_text=f"{unique_term} item number {i}",
+            searchable_text=f"pagination test item number {i}",
             metadata={"index": i},
             user_id=None,
         )

     # Get first page
     page1_results, total1 = await unified_hybrid_search(
-        query=unique_term,
+        query="pagination test",
         content_types=[ContentType.BLOCK],
         page=1,
         page_size=2,
@@ -483,7 +480,7 @@ async def test_unified_hybrid_search_pagination(
     # Get second page
     page2_results, total2 = await unified_hybrid_search(
-        query=unique_term,
+        query="pagination test",
         content_types=[ContentType.BLOCK],
         page=2,
         page_size=2,

View File

@@ -40,10 +40,6 @@ import backend.data.user
 import backend.integrations.webhooks.utils
 import backend.util.service
 import backend.util.settings
-from backend.api.features.chat.completion_consumer import (
-    start_completion_consumer,
-    stop_completion_consumer,
-)
 from backend.blocks.llm import DEFAULT_LLM_MODEL
 from backend.data.model import Credentials
 from backend.integrations.providers import ProviderName
@@ -122,21 +118,9 @@ async def lifespan_context(app: fastapi.FastAPI):
     await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
     await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()

-    # Start chat completion consumer for Redis Streams notifications
-    try:
-        await start_completion_consumer()
-    except Exception as e:
-        logger.warning(f"Could not start chat completion consumer: {e}")
-
     with launch_darkly_context():
         yield

-    # Stop chat completion consumer
-    try:
-        await stop_completion_consumer()
-    except Exception as e:
-        logger.warning(f"Error stopping chat completion consumer: {e}")
-
     try:
         await shutdown_cloud_storage_handler()
     except Exception as e:

View File

@@ -66,24 +66,18 @@ async def event_broadcaster(manager: ConnectionManager):
     execution_bus = AsyncRedisExecutionEventBus()
     notification_bus = AsyncRedisNotificationEventBus()

-    try:
-        async def execution_worker():
-            async for event in execution_bus.listen("*"):
-                await manager.send_execution_update(event)
-
-        async def notification_worker():
-            async for notification in notification_bus.listen("*"):
-                await manager.send_notification(
-                    user_id=notification.user_id,
-                    payload=notification.payload,
-                )
-
-        await asyncio.gather(execution_worker(), notification_worker())
-    finally:
-        # Ensure PubSub connections are closed on any exit to prevent leaks
-        await execution_bus.close()
-        await notification_bus.close()
+    async def execution_worker():
+        async for event in execution_bus.listen("*"):
+            await manager.send_execution_update(event)
+
+    async def notification_worker():
+        async for notification in notification_bus.listen("*"):
+            await manager.send_notification(
+                user_id=notification.user_id,
+                payload=notification.payload,
+            )
+
+    await asyncio.gather(execution_worker(), notification_worker())


 async def authenticate_websocket(websocket: WebSocket) -> str:

View File

@@ -32,7 +32,7 @@ from backend.data.model import (
 from backend.integrations.providers import ProviderName
 from backend.util import json
 from backend.util.logging import TruncatedLogger
-from backend.util.prompt import compress_context, estimate_token_count
+from backend.util.prompt import compress_prompt, estimate_token_count
 from backend.util.text import TextFormatter

 logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
@@ -634,18 +634,11 @@ async def llm_call(
     context_window = llm_model.context_window
     if compress_prompt_to_fit:
-        result = await compress_context(
+        prompt = compress_prompt(
             messages=prompt,
             target_tokens=llm_model.context_window // 2,
-            client=None,  # Truncation-only, no LLM summarization
-            reserve=0,  # Caller handles response token budget separately
+            lossy_ok=True,
         )
-        if result.error:
-            logger.warning(
-                f"Prompt compression did not meet target: {result.error}. "
-                f"Proceeding with {result.token_count} tokens."
-            )
-        prompt = result.messages

     # Calculate available tokens based on context window and input length
     estimated_input_tokens = estimate_token_count(prompt)

View File

@@ -873,13 +873,14 @@ def is_block_auth_configured(
 async def initialize_blocks() -> None:
+    # First, sync all provider costs to blocks
+    # Imported here to avoid circular import
     from backend.sdk.cost_integration import sync_all_provider_costs
-    from backend.util.retry import func_retry

     sync_all_provider_costs()

-    @func_retry
-    async def sync_block_to_db(block: Block) -> None:
+    for cls in get_blocks().values():
+        block = cls()
         existing_block = await AgentBlock.prisma().find_first(
             where={"OR": [{"id": block.id}, {"name": block.name}]}
         )
@@ -892,7 +893,7 @@ async def initialize_blocks() -> None:
                     outputSchema=json.dumps(block.output_schema.jsonschema()),
                 )
             )
-            return
+            continue

         input_schema = json.dumps(block.input_schema.jsonschema())
         output_schema = json.dumps(block.output_schema.jsonschema())
@@ -912,25 +913,6 @@ async def initialize_blocks() -> None:
             },
         )

-    failed_blocks: list[str] = []
-    for cls in get_blocks().values():
-        block = cls()
-        try:
-            await sync_block_to_db(block)
-        except Exception as e:
-            logger.warning(
-                f"Failed to sync block {block.name} to database: {e}. "
-                "Block is still available in memory.",
-                exc_info=True,
-            )
-            failed_blocks.append(block.name)
-
-    if failed_blocks:
-        logger.error(
-            f"Failed to sync {len(failed_blocks)} block(s) to database: "
-            f"{', '.join(failed_blocks)}. These blocks are still available in memory."
-        )
-

 # Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
 def get_block(block_id: str) -> AnyBlockSchema | None:

View File

@@ -133,23 +133,10 @@ class RedisEventBus(BaseRedisEventBus[M], ABC):
 class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
-    def __init__(self):
-        self._pubsub: AsyncPubSub | None = None
-
     @property
     async def connection(self) -> redis.AsyncRedis:
         return await redis.get_redis_async()

-    async def close(self) -> None:
-        """Close the PubSub connection if it exists."""
-        if self._pubsub is not None:
-            try:
-                await self._pubsub.close()
-            except Exception:
-                logger.warning("Failed to close PubSub connection", exc_info=True)
-            finally:
-                self._pubsub = None
-
     async def publish_event(self, event: M, channel_key: str):
         """
         Publish an event to Redis. Gracefully handles connection failures
@@ -170,7 +157,6 @@ class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
             await self.connection, channel_key
         )
         assert isinstance(pubsub, AsyncPubSub)
-        self._pubsub = pubsub

         if "*" in channel_key:
             await pubsub.psubscribe(full_channel_name)

View File

@@ -17,7 +17,6 @@ from backend.data.analytics import (
     get_accuracy_trends_and_alerts,
     get_marketplace_graphs_for_monitoring,
 )
-from backend.data.auth.oauth import cleanup_expired_oauth_tokens
 from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
 from backend.data.execution import (
     create_graph_execution,
@@ -220,9 +219,6 @@ class DatabaseManager(AppService):
     # Onboarding
     increment_onboarding_runs = _(increment_onboarding_runs)

-    # OAuth
-    cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens)
-
     # Store
     get_store_agents = _(get_store_agents)
     get_store_agent_details = _(get_store_agent_details)
@@ -353,9 +349,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
     # Onboarding
     increment_onboarding_runs = d.increment_onboarding_runs

-    # OAuth
-    cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens
-
     # Store
     get_store_agents = d.get_store_agents
     get_store_agent_details = d.get_store_agent_details

View File

@@ -24,9 +24,11 @@ from dotenv import load_dotenv
 from pydantic import BaseModel, Field, ValidationError
 from sqlalchemy import MetaData, create_engine

+from backend.data.auth.oauth import cleanup_expired_oauth_tokens
 from backend.data.block import BlockInput
 from backend.data.execution import GraphExecutionWithNodes
 from backend.data.model import CredentialsMetaInput
+from backend.data.onboarding import increment_onboarding_runs
 from backend.executor import utils as execution_utils
 from backend.monitoring import (
     NotificationJobArgs,
@@ -36,11 +38,7 @@ from backend.monitoring import (
     report_execution_accuracy_alerts,
     report_late_executions,
 )
-from backend.util.clients import (
-    get_database_manager_async_client,
-    get_database_manager_client,
-    get_scheduler_client,
-)
+from backend.util.clients import get_database_manager_client, get_scheduler_client
 from backend.util.cloud_storage import cleanup_expired_files_async
 from backend.util.exceptions import (
     GraphNotFoundError,
@@ -150,7 +148,6 @@ def execute_graph(**kwargs):
 async def _execute_graph(**kwargs):
     args = GraphExecutionJobArgs(**kwargs)
     start_time = asyncio.get_event_loop().time()
-    db = get_database_manager_async_client()
     try:
         logger.info(f"Executing recurring job for graph #{args.graph_id}")
         graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution(
@@ -160,7 +157,7 @@ async def _execute_graph(**kwargs):
             inputs=args.input_data,
             graph_credentials_inputs=args.input_credentials,
         )
-        await db.increment_onboarding_runs(args.user_id)
+        await increment_onboarding_runs(args.user_id)
         elapsed = asyncio.get_event_loop().time() - start_time
         logger.info(
             f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
@@ -196,9 +193,11 @@ async def _handle_graph_validation_error(args: "GraphExecutionJobArgs") -> None:
             user_id=args.user_id,
         )
     else:
-        logger.error(
-            f"Unable to unschedule graph: {args.graph_id} as this is an old job with no associated schedule_id please remove manually"
+        logger.warning(
+            f"Old scheduled job for graph {args.graph_id} (user {args.user_id}) "
+            f"has no schedule_id, attempting targeted cleanup"
         )
+        await _cleanup_old_schedules_without_id(args.graph_id, args.user_id)


 async def _handle_graph_not_available(
@@ -241,6 +240,35 @@ async def _cleanup_orphaned_schedules_for_graph(graph_id: str, user_id: str) ->
     )


+async def _cleanup_old_schedules_without_id(graph_id: str, user_id: str) -> None:
+    """Remove only schedules that have no schedule_id in their job args.
+
+    Unlike _cleanup_orphaned_schedules_for_graph (which removes ALL schedules
+    for a graph), this only targets legacy jobs created before schedule_id was
+    added to GraphExecutionJobArgs, preserving any valid newer schedules.
+    """
+    scheduler_client = get_scheduler_client()
+    schedules = await scheduler_client.get_execution_schedules(
+        graph_id=graph_id, user_id=user_id
+    )
+    for schedule in schedules:
+        if schedule.schedule_id is not None:
+            continue
+        try:
+            await scheduler_client.delete_schedule(
+                schedule_id=schedule.id, user_id=user_id
+            )
+            logger.info(
+                f"Cleaned up old schedule {schedule.id} (no schedule_id) "
+                f"for graph {graph_id}"
+            )
+        except Exception:
+            logger.exception(
+                f"Failed to delete old schedule {schedule.id} for graph {graph_id}"
+            )
+
+
 def cleanup_expired_files():
     """Clean up expired files from cloud storage."""
     # Wait for completion
@@ -249,13 +277,8 @@ def cleanup_expired_files():
 def cleanup_oauth_tokens():
     """Clean up expired OAuth tokens from the database."""
     # Wait for completion
-    async def _cleanup():
-        db = get_database_manager_async_client()
-        return await db.cleanup_expired_oauth_tokens()
-
-    run_async(_cleanup())
+    run_async(cleanup_expired_oauth_tokens())


 def execution_accuracy_alerts():
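
The heart of the fix is the schedule_id filter: only legacy jobs (schedule_id=None) are deleted, so newer user-created schedules survive. A minimal standalone sketch of that predicate (the Schedule shape below is a simplified stand-in for the scheduler client's objects, not the real type):

from dataclasses import dataclass
from typing import Optional


@dataclass
class Schedule:
    # Simplified stand-in for the scheduler client's schedule object
    id: str
    schedule_id: Optional[str]  # None: legacy job created before the field existed


def legacy_schedules(schedules: list[Schedule]) -> list[Schedule]:
    """Keep only legacy jobs; anything carrying a schedule_id survives."""
    return [s for s in schedules if s.schedule_id is None]


schedules = [Schedule("old-job", None), Schedule("new-job", "sched-123")]
assert [s.id for s in legacy_schedules(schedules)] == ["old-job"]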

View File

@@ -1,19 +1,10 @@
-from __future__ import annotations
-
-import logging
 from copy import deepcopy
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any
+from typing import Any

 from tiktoken import encoding_for_model

 from backend.util import json

-if TYPE_CHECKING:
-    from openai import AsyncOpenAI
-
-logger = logging.getLogger(__name__)
-
 # ---------------------------------------------------------------------------#
 # CONSTANTS                                                                   #
 # ---------------------------------------------------------------------------#
@@ -109,17 +100,9 @@ def _is_objective_message(msg: dict) -> bool:
 def _truncate_tool_message_content(msg: dict, enc, max_tokens: int) -> None:
     """
     Carefully truncate tool message content while preserving tool structure.
-    Handles both Anthropic-style (list content) and OpenAI-style (string content) tool messages.
+    Only truncates tool_result content, leaves tool_use intact.
     """
     content = msg.get("content")
-    # OpenAI-style tool message: role="tool" with string content
-    if msg.get("role") == "tool" and isinstance(content, str):
-        if _tok_len(content, enc) > max_tokens:
-            msg["content"] = _truncate_middle_tokens(content, enc, max_tokens)
-        return
-    # Anthropic-style: list content with tool_result items
     if not isinstance(content, list):
         return
@@ -157,6 +140,141 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
# ---------------------------------------------------------------------------#
def compress_prompt(
messages: list[dict],
target_tokens: int,
*,
model: str = "gpt-4o",
reserve: int = 2_048,
start_cap: int = 8_192,
floor_cap: int = 128,
lossy_ok: bool = True,
) -> list[dict]:
"""
Shrink *messages* so that::
token_count(prompt) + reserve ≤ target_tokens
Strategy
--------
1. **Token-aware truncation** progressively halve a per-message cap
(`start_cap`, `start_cap/2`, … `floor_cap`) and apply it to the
*content* of every message except the first and last. Tool shells
are included: we keep the envelope but shorten huge payloads.
2. **Middle-out deletion** if still over the limit, delete whole
messages working outward from the centre, **skipping** any message
that contains ``tool_calls`` or has ``role == "tool"``.
3. **Last-chance trim** if still too big, truncate the *first* and
*last* message bodies down to `floor_cap` tokens.
4. If the prompt is *still* too large:
• raise ``ValueError`` when ``lossy_ok == False``
• return the partially-trimmed prompt when ``lossy_ok == True`` (default)
Parameters
----------
messages Complete chat history (will be deep-copied).
model Model name; passed to tiktoken to pick the right
tokenizer (gpt-4o → 'o200k_base', others fallback).
target_tokens Hard ceiling for prompt size **excluding** the model's
forthcoming answer.
reserve How many tokens you want to leave available for that
answer (`max_tokens` in your subsequent completion call).
start_cap Initial per-message truncation ceiling (tokens).
floor_cap Lowest cap we'll accept before moving to deletions.
lossy_ok If *True* return best-effort prompt instead of raising
after all trim passes have been exhausted.
Returns
-------
list[dict] A *new* messages list that abides by the rules above.
"""
enc = encoding_for_model(model) # best-match tokenizer
msgs = deepcopy(messages) # never mutate caller
def total_tokens() -> int:
"""Current size of *msgs* in tokens."""
return sum(_msg_tokens(m, enc) for m in msgs)
original_token_count = total_tokens()
if original_token_count + reserve <= target_tokens:
return msgs
# ---- STEP 0 : normalise content --------------------------------------
# Convert non-string payloads to strings so token counting is coherent.
for i, m in enumerate(msgs):
if not isinstance(m.get("content"), str) and m.get("content") is not None:
if _is_tool_message(m):
continue
# Keep first and last messages intact (unless they're tool messages)
if i == 0 or i == len(msgs) - 1:
continue
# Reasonable 20k-char ceiling prevents pathological blobs
content_str = json.dumps(m["content"], separators=(",", ":"))
if len(content_str) > 20_000:
content_str = _truncate_middle_tokens(content_str, enc, 20_000)
m["content"] = content_str
# ---- STEP 1 : token-aware truncation ---------------------------------
cap = start_cap
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
for m in msgs[1:-1]: # keep first & last intact
if _is_tool_message(m):
# For tool messages, only truncate tool result content, preserve structure
_truncate_tool_message_content(m, enc, cap)
continue
if _is_objective_message(m):
# Never truncate objective messages - they contain the core task
continue
content = m.get("content") or ""
if _tok_len(content, enc) > cap:
m["content"] = _truncate_middle_tokens(content, enc, cap)
cap //= 2 # tighten the screw
# ---- STEP 2 : middle-out deletion -----------------------------------
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
# Identify all deletable messages (not first/last, not tool messages, not objective messages)
deletable_indices = []
for i in range(1, len(msgs) - 1): # Skip first and last
if not _is_tool_message(msgs[i]) and not _is_objective_message(msgs[i]):
deletable_indices.append(i)
if not deletable_indices:
break # nothing more we can drop
# Delete from center outward - find the index closest to center
centre = len(msgs) // 2
to_delete = min(deletable_indices, key=lambda i: abs(i - centre))
del msgs[to_delete]
# ---- STEP 3 : final safety-net trim on first & last ------------------
cap = start_cap
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
for idx in (0, -1): # first and last
if _is_tool_message(msgs[idx]):
# For tool messages at first/last position, truncate tool result content only
_truncate_tool_message_content(msgs[idx], enc, cap)
continue
text = msgs[idx].get("content") or ""
if _tok_len(text, enc) > cap:
msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap)
cap //= 2 # tighten the screw
# ---- STEP 4 : success or fail-gracefully -----------------------------
if total_tokens() + reserve > target_tokens and not lossy_ok:
raise ValueError(
"compress_prompt: prompt still exceeds budget "
f"({total_tokens() + reserve} > {target_tokens})."
)
return msgs
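
# Aside (not part of the diff): a minimal usage sketch of compress_prompt,
# assuming the module above is importable. Middle messages are truncated and
# then deleted centre-out until the history plus `reserve` fits the budget.
history = [
    {"role": "system", "content": "You are a helpful assistant."},
    *({"role": "user", "content": f"note {i}: " + "filler " * 2_000} for i in range(10)),
    {"role": "user", "content": "Summarize everything above."},
]
small = compress_prompt(history, target_tokens=16_000, reserve=2_048)
# With lossy_ok=True (the default) this holds unless even the fallback trims
# could not get under budget:
assert estimate_token_count(small, model="gpt-4o") + 2_048 <= 16_000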
def estimate_token_count(
    messages: list[dict],
    *,
@@ -175,8 +293,7 @@ def estimate_token_count(
     -------
     int     Token count.
     """
-    token_model = _normalize_model_for_tokenizer(model)
-    enc = encoding_for_model(token_model)
+    enc = encoding_for_model(model)  # best-match tokenizer
     return sum(_msg_tokens(m, enc) for m in messages)
@@ -198,543 +315,6 @@ def estimate_token_count_str(
     -------
     int     Token count.
     """
-    token_model = _normalize_model_for_tokenizer(model)
-    enc = encoding_for_model(token_model)
+    enc = encoding_for_model(model)  # best-match tokenizer
     text = json.dumps(text) if not isinstance(text, str) else text
     return _tok_len(text, enc)
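
# Aside (not part of the diff): the two estimators differ only in input type;
# after this change both pass the model name straight to tiktoken.
assert estimate_token_count(
    [{"role": "user", "content": "Hello, world"}], model="gpt-4o"
) > 0
assert estimate_token_count_str("Hello, world", model="gpt-4o") > 0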
# ---------------------------------------------------------------------------#
# UNIFIED CONTEXT COMPRESSION #
# ---------------------------------------------------------------------------#
# Default thresholds
DEFAULT_TOKEN_THRESHOLD = 120_000
DEFAULT_KEEP_RECENT = 15
@dataclass
class CompressResult:
"""Result of context compression."""
messages: list[dict]
token_count: int
was_compacted: bool
error: str | None = None
original_token_count: int = 0
messages_summarized: int = 0
messages_dropped: int = 0
def _normalize_model_for_tokenizer(model: str) -> str:
"""Normalize model name for tiktoken tokenizer selection."""
if "/" in model:
model = model.split("/")[-1]
if "claude" in model.lower() or not any(
known in model.lower() for known in ["gpt", "o1", "chatgpt", "text-"]
):
return "gpt-4o"
return model
def _extract_tool_call_ids_from_message(msg: dict) -> set[str]:
"""
Extract tool_call IDs from an assistant message.
Supports both formats:
- OpenAI: {"role": "assistant", "tool_calls": [{"id": "..."}]}
- Anthropic: {"role": "assistant", "content": [{"type": "tool_use", "id": "..."}]}
Returns:
Set of tool_call IDs found in the message.
"""
ids: set[str] = set()
if msg.get("role") != "assistant":
return ids
# OpenAI format: tool_calls array
if msg.get("tool_calls"):
for tc in msg["tool_calls"]:
tc_id = tc.get("id")
if tc_id:
ids.add(tc_id)
# Anthropic format: content list with tool_use blocks
content = msg.get("content")
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "tool_use":
tc_id = block.get("id")
if tc_id:
ids.add(tc_id)
return ids
def _extract_tool_response_ids_from_message(msg: dict) -> set[str]:
"""
Extract tool_call IDs that this message is responding to.
Supports both formats:
- OpenAI: {"role": "tool", "tool_call_id": "..."}
- Anthropic: {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "..."}]}
Returns:
Set of tool_call IDs this message responds to.
"""
ids: set[str] = set()
# OpenAI format: role=tool with tool_call_id
if msg.get("role") == "tool":
tc_id = msg.get("tool_call_id")
if tc_id:
ids.add(tc_id)
# Anthropic format: content list with tool_result blocks
content = msg.get("content")
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "tool_result":
tc_id = block.get("tool_use_id")
if tc_id:
ids.add(tc_id)
return ids
def _is_tool_response_message(msg: dict) -> bool:
"""Check if message is a tool response (OpenAI or Anthropic format)."""
# OpenAI format
if msg.get("role") == "tool":
return True
# Anthropic format
content = msg.get("content")
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "tool_result":
return True
return False
def _remove_orphan_tool_responses(
messages: list[dict], orphan_ids: set[str]
) -> list[dict]:
"""
Remove tool response messages/blocks that reference orphan tool_call IDs.
Supports both OpenAI and Anthropic formats.
For Anthropic messages with mixed valid/orphan tool_result blocks,
filters out only the orphan blocks instead of dropping the entire message.
"""
result = []
for msg in messages:
# OpenAI format: role=tool - drop entire message if orphan
if msg.get("role") == "tool":
tc_id = msg.get("tool_call_id")
if tc_id and tc_id in orphan_ids:
continue
result.append(msg)
continue
# Anthropic format: content list may have mixed tool_result blocks
content = msg.get("content")
if isinstance(content, list):
has_tool_results = any(
isinstance(b, dict) and b.get("type") == "tool_result" for b in content
)
if has_tool_results:
# Filter out orphan tool_result blocks, keep valid ones
filtered_content = [
block
for block in content
if not (
isinstance(block, dict)
and block.get("type") == "tool_result"
and block.get("tool_use_id") in orphan_ids
)
]
# Only keep message if it has remaining content
if filtered_content:
msg = msg.copy()
msg["content"] = filtered_content
result.append(msg)
continue
result.append(msg)
return result
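
# Aside (not part of the diff): behaviour pinned by the tests further down.
# A whole OpenAI-style tool message is dropped when its call id is orphaned;
# Anthropic-style messages only lose the orphaned tool_result blocks.
_msgs = [
    {"role": "tool", "tool_call_id": "call_orphan", "content": "result"},
    {"role": "user", "content": "hello"},
]
assert _remove_orphan_tool_responses(_msgs, {"call_orphan"}) == [
    {"role": "user", "content": "hello"}
]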
def _ensure_tool_pairs_intact(
recent_messages: list[dict],
all_messages: list[dict],
start_index: int,
) -> list[dict]:
"""
Ensure tool_call/tool_response pairs stay together after slicing.
When slicing messages for context compaction, a naive slice can separate
an assistant message containing tool_calls from its corresponding tool
response messages. This causes API validation errors (e.g., Anthropic's
"unexpected tool_use_id found in tool_result blocks").
This function checks for orphan tool responses in the slice and extends
backwards to include their corresponding assistant messages.
Supports both formats:
- OpenAI: tool_calls array + role="tool" responses
- Anthropic: tool_use blocks + tool_result blocks
Args:
recent_messages: The sliced messages to validate
all_messages: The complete message list (for looking up missing assistants)
start_index: The index in all_messages where recent_messages begins
Returns:
A potentially extended list of messages with tool pairs intact
"""
if not recent_messages:
return recent_messages
# Collect all tool_call_ids from assistant messages in the slice
available_tool_call_ids: set[str] = set()
for msg in recent_messages:
available_tool_call_ids |= _extract_tool_call_ids_from_message(msg)
# Find orphan tool responses (responses whose tool_call_id is missing)
orphan_tool_call_ids: set[str] = set()
for msg in recent_messages:
response_ids = _extract_tool_response_ids_from_message(msg)
for tc_id in response_ids:
if tc_id not in available_tool_call_ids:
orphan_tool_call_ids.add(tc_id)
if not orphan_tool_call_ids:
# No orphans, slice is valid
return recent_messages
# Find the assistant messages that contain the orphan tool_call_ids
# Search backwards from start_index in all_messages
messages_to_prepend: list[dict] = []
for i in range(start_index - 1, -1, -1):
msg = all_messages[i]
msg_tool_ids = _extract_tool_call_ids_from_message(msg)
if msg_tool_ids & orphan_tool_call_ids:
# This assistant message has tool_calls we need
# Also collect its contiguous tool responses that follow it
assistant_and_responses: list[dict] = [msg]
# Scan forward from this assistant to collect tool responses
for j in range(i + 1, start_index):
following_msg = all_messages[j]
following_response_ids = _extract_tool_response_ids_from_message(
following_msg
)
if following_response_ids and following_response_ids & msg_tool_ids:
assistant_and_responses.append(following_msg)
elif not _is_tool_response_message(following_msg):
# Stop at first non-tool-response message
break
# Prepend the assistant and its tool responses (maintain order)
messages_to_prepend = assistant_and_responses + messages_to_prepend
# Mark these as found
orphan_tool_call_ids -= msg_tool_ids
# Also add this assistant's tool_call_ids to available set
available_tool_call_ids |= msg_tool_ids
if not orphan_tool_call_ids:
# Found all missing assistants
break
if orphan_tool_call_ids:
# Some tool_call_ids couldn't be resolved - remove those tool responses
# This shouldn't happen in normal operation but handles edge cases
logger.warning(
f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
"Removing orphan tool responses."
)
recent_messages = _remove_orphan_tool_responses(
recent_messages, orphan_tool_call_ids
)
if messages_to_prepend:
logger.info(
f"Extended recent messages by {len(messages_to_prepend)} to preserve "
f"tool_call/tool_response pairs"
)
return messages_to_prepend + recent_messages
return recent_messages
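
# Aside (not part of the diff): the repair in action, mirroring the tests
# below. A naive slice that starts at the tool response gets the assistant
# message carrying the matching tool_calls pulled back in front of it.
_all = [
    {"role": "system", "content": "sys"},
    {
        "role": "assistant",
        "tool_calls": [
            {"id": "call_1", "type": "function", "function": {"name": "f1"}}
        ],
    },
    {"role": "tool", "tool_call_id": "call_1", "content": "result"},
    {"role": "user", "content": "Thanks!"},
]
_fixed = _ensure_tool_pairs_intact(_all[2:], _all, start_index=2)
assert len(_fixed) == 3 and _fixed[0]["role"] == "assistant"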
async def _summarize_messages_llm(
messages: list[dict],
client: AsyncOpenAI,
model: str,
timeout: float = 30.0,
) -> str:
"""Summarize messages using an LLM."""
conversation = []
for msg in messages:
role = msg.get("role", "")
content = msg.get("content", "")
if content and role in ("user", "assistant", "tool"):
conversation.append(f"{role.upper()}: {content}")
conversation_text = "\n\n".join(conversation)
if not conversation_text:
return "No conversation history available."
# Limit to ~100k chars for safety
MAX_CHARS = 100_000
if len(conversation_text) > MAX_CHARS:
conversation_text = conversation_text[:MAX_CHARS] + "\n\n[truncated]"
response = await client.with_options(timeout=timeout).chat.completions.create(
model=model,
messages=[
{
"role": "system",
"content": (
"Create a detailed summary of the conversation so far. "
"This summary will be used as context when continuing the conversation.\n\n"
"Before writing the summary, analyze each message chronologically to identify:\n"
"- User requests and their explicit goals\n"
"- Your approach and key decisions made\n"
"- Technical specifics (file names, tool outputs, function signatures)\n"
"- Errors encountered and resolutions applied\n\n"
"You MUST include ALL of the following sections:\n\n"
"## 1. Primary Request and Intent\n"
"The user's explicit goals and what they are trying to accomplish.\n\n"
"## 2. Key Technical Concepts\n"
"Technologies, frameworks, tools, and patterns being used or discussed.\n\n"
"## 3. Files and Resources Involved\n"
"Specific files examined or modified, with relevant snippets and identifiers.\n\n"
"## 4. Errors and Fixes\n"
"Problems encountered, error messages, and their resolutions. "
"Include any user feedback on fixes.\n\n"
"## 5. Problem Solving\n"
"Issues that have been resolved and how they were addressed.\n\n"
"## 6. All User Messages\n"
"A complete list of all user inputs (excluding tool outputs) to preserve their exact requests.\n\n"
"## 7. Pending Tasks\n"
"Work items the user explicitly requested that have not yet been completed.\n\n"
"## 8. Current Work\n"
"Precise description of what was being worked on most recently, including relevant context.\n\n"
"## 9. Next Steps\n"
"What should happen next, aligned with the user's most recent requests. "
"Include verbatim quotes of recent instructions if relevant."
),
},
{"role": "user", "content": f"Summarize:\n\n{conversation_text}"},
],
max_tokens=1500,
temperature=0.3,
)
return response.choices[0].message.content or "No summary available."
async def compress_context(
messages: list[dict],
target_tokens: int = DEFAULT_TOKEN_THRESHOLD,
*,
model: str = "gpt-4o",
client: AsyncOpenAI | None = None,
keep_recent: int = DEFAULT_KEEP_RECENT,
reserve: int = 2_048,
start_cap: int = 8_192,
floor_cap: int = 128,
) -> CompressResult:
"""
Unified context compression that combines summarization and truncation strategies.
Strategy (in order):
1. **LLM summarization** If client provided, summarize old messages into a
single context message while keeping recent messages intact. This is the
primary strategy for chat service.
2. **Content truncation** Progressively halve a per-message cap and truncate
bloated message content (tool outputs, large pastes). Preserves all messages
but shortens their content. Primary strategy when client=None (LLM blocks).
3. **Middle-out deletion** Delete whole messages one at a time from the center
outward, skipping tool messages and objective messages.
4. **First/last trim** Truncate first and last message content as last resort.
Parameters
----------
messages Complete chat history (will be deep-copied).
target_tokens Hard ceiling for prompt size.
model Model name for tokenization and summarization.
client AsyncOpenAI client. If provided, enables LLM summarization
as the first strategy. If None, skips to truncation strategies.
keep_recent Number of recent messages to preserve during summarization.
reserve Tokens to reserve for model response.
start_cap Initial per-message truncation ceiling (tokens).
floor_cap Lowest cap before moving to deletions.
Returns
-------
CompressResult with compressed messages and metadata.
"""
# Guard clause for empty messages
if not messages:
return CompressResult(
messages=[],
token_count=0,
was_compacted=False,
original_token_count=0,
)
token_model = _normalize_model_for_tokenizer(model)
enc = encoding_for_model(token_model)
msgs = deepcopy(messages)
def total_tokens() -> int:
return sum(_msg_tokens(m, enc) for m in msgs)
original_count = total_tokens()
# Already under limit
if original_count + reserve <= target_tokens:
return CompressResult(
messages=msgs,
token_count=original_count,
was_compacted=False,
original_token_count=original_count,
)
messages_summarized = 0
messages_dropped = 0
# ---- STEP 1: LLM summarization (if client provided) -------------------
# This is the primary compression strategy for chat service.
# Summarize old messages while keeping recent ones intact.
if client is not None:
has_system = len(msgs) > 0 and msgs[0].get("role") == "system"
system_msg = msgs[0] if has_system else None
# Calculate old vs recent messages
if has_system:
if len(msgs) > keep_recent + 1:
old_msgs = msgs[1:-keep_recent]
recent_msgs = msgs[-keep_recent:]
else:
old_msgs = []
recent_msgs = msgs[1:] if len(msgs) > 1 else []
else:
if len(msgs) > keep_recent:
old_msgs = msgs[:-keep_recent]
recent_msgs = msgs[-keep_recent:]
else:
old_msgs = []
recent_msgs = msgs
# Ensure tool pairs stay intact
slice_start = max(0, len(msgs) - keep_recent)
recent_msgs = _ensure_tool_pairs_intact(recent_msgs, msgs, slice_start)
if old_msgs:
try:
summary_text = await _summarize_messages_llm(old_msgs, client, model)
summary_msg = {
"role": "assistant",
"content": f"[Previous conversation summary — for context only]: {summary_text}",
}
messages_summarized = len(old_msgs)
if has_system:
msgs = [system_msg, summary_msg] + recent_msgs
else:
msgs = [summary_msg] + recent_msgs
logger.info(
f"Context summarized: {original_count} -> {total_tokens()} tokens, "
f"summarized {messages_summarized} messages"
)
except Exception as e:
logger.warning(f"Summarization failed, continuing with truncation: {e}")
# Fall through to content truncation
# ---- STEP 2: Normalize content ----------------------------------------
# Convert non-string payloads to strings so token counting is coherent.
# Always run this before truncation to ensure consistent token counting.
for i, m in enumerate(msgs):
if not isinstance(m.get("content"), str) and m.get("content") is not None:
if _is_tool_message(m):
continue
if i == 0 or i == len(msgs) - 1:
continue
content_str = json.dumps(m["content"], separators=(",", ":"))
if len(content_str) > 20_000:
content_str = _truncate_middle_tokens(content_str, enc, 20_000)
m["content"] = content_str
# ---- STEP 3: Token-aware content truncation ---------------------------
# Progressively halve per-message cap and truncate bloated content.
# This preserves all messages but shortens their content.
cap = start_cap
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
for m in msgs[1:-1]:
if _is_tool_message(m):
_truncate_tool_message_content(m, enc, cap)
continue
if _is_objective_message(m):
continue
content = m.get("content") or ""
if _tok_len(content, enc) > cap:
m["content"] = _truncate_middle_tokens(content, enc, cap)
cap //= 2
# ---- STEP 4: Middle-out deletion --------------------------------------
# Delete messages one at a time from the center outward.
# This is more granular than dropping all old messages at once.
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
deletable: list[int] = []
for i in range(1, len(msgs) - 1):
msg = msgs[i]
if (
msg is not None
and not _is_tool_message(msg)
and not _is_objective_message(msg)
):
deletable.append(i)
if not deletable:
break
centre = len(msgs) // 2
to_delete = min(deletable, key=lambda i: abs(i - centre))
del msgs[to_delete]
messages_dropped += 1
# ---- STEP 5: Final trim on first/last ---------------------------------
cap = start_cap
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
for idx in (0, -1):
msg = msgs[idx]
if msg is None:
continue
if _is_tool_message(msg):
_truncate_tool_message_content(msg, enc, cap)
continue
text = msg.get("content") or ""
if _tok_len(text, enc) > cap:
msg["content"] = _truncate_middle_tokens(text, enc, cap)
cap //= 2
# Filter out any None values that may have been introduced
final_msgs: list[dict] = [m for m in msgs if m is not None]
final_count = sum(_msg_tokens(m, enc) for m in final_msgs)
error = None
if final_count + reserve > target_tokens:
error = f"Could not compress below target ({final_count + reserve} > {target_tokens})"
logger.warning(error)
return CompressResult(
messages=final_msgs,
token_count=final_count,
was_compacted=True,
error=error,
original_token_count=original_count,
messages_summarized=messages_summarized,
messages_dropped=messages_dropped,
)
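
For reference, this is roughly how the removed compress_context entry point was driven (a sketch under the signature shown above; the AsyncOpenAI client is optional and only enables the summarization pass):

import asyncio

from openai import AsyncOpenAI

from backend.util.prompt import compress_context  # pre-removal revision


async def main() -> None:
    messages = [{"role": "user", "content": "filler " * 100_000}]
    result = await compress_context(
        messages,
        target_tokens=120_000,
        model="gpt-4o",
        client=AsyncOpenAI(),  # pass client=None to skip straight to truncation
        keep_recent=15,
    )
    if result.was_compacted:
        print(result.original_token_count, "->", result.token_count)


asyncio.run(main())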

View File

@@ -1,21 +1,10 @@
"""Tests for prompt utility functions, especially tool call token counting.""" """Tests for prompt utility functions, especially tool call token counting."""
from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
from tiktoken import encoding_for_model from tiktoken import encoding_for_model
from backend.util import json from backend.util import json
from backend.util.prompt import ( from backend.util.prompt import _msg_tokens, estimate_token_count
CompressResult,
_ensure_tool_pairs_intact,
_msg_tokens,
_normalize_model_for_tokenizer,
_truncate_middle_tokens,
_truncate_tool_message_content,
compress_context,
estimate_token_count,
)
class TestMsgTokens: class TestMsgTokens:
@@ -287,690 +276,3 @@ class TestEstimateTokenCount:
assert total_tokens == expected_total assert total_tokens == expected_total
assert total_tokens > 20 # Should be substantial assert total_tokens > 20 # Should be substantial
class TestNormalizeModelForTokenizer:
"""Test model name normalization for tiktoken."""
def test_openai_models_unchanged(self):
"""Test that OpenAI models are returned as-is."""
assert _normalize_model_for_tokenizer("gpt-4o") == "gpt-4o"
assert _normalize_model_for_tokenizer("gpt-4") == "gpt-4"
assert _normalize_model_for_tokenizer("gpt-3.5-turbo") == "gpt-3.5-turbo"
def test_claude_models_normalized(self):
"""Test that Claude models are normalized to gpt-4o."""
assert _normalize_model_for_tokenizer("claude-3-opus") == "gpt-4o"
assert _normalize_model_for_tokenizer("claude-3-sonnet") == "gpt-4o"
assert _normalize_model_for_tokenizer("anthropic/claude-3-haiku") == "gpt-4o"
def test_openrouter_paths_extracted(self):
"""Test that OpenRouter model paths are handled."""
assert _normalize_model_for_tokenizer("openai/gpt-4o") == "gpt-4o"
assert _normalize_model_for_tokenizer("anthropic/claude-3-opus") == "gpt-4o"
def test_unknown_models_default_to_gpt4o(self):
"""Test that unknown models default to gpt-4o."""
assert _normalize_model_for_tokenizer("some-random-model") == "gpt-4o"
assert _normalize_model_for_tokenizer("llama-3-70b") == "gpt-4o"
class TestTruncateToolMessageContent:
"""Test tool message content truncation."""
@pytest.fixture
def enc(self):
return encoding_for_model("gpt-4o")
def test_truncate_openai_tool_message(self, enc):
"""Test truncation of OpenAI-style tool message with string content."""
long_content = "x" * 10000
msg = {"role": "tool", "tool_call_id": "call_123", "content": long_content}
_truncate_tool_message_content(msg, enc, max_tokens=100)
# Content should be truncated
assert len(msg["content"]) < len(long_content)
assert "" in msg["content"] # Has ellipsis marker
def test_truncate_anthropic_tool_result(self, enc):
"""Test truncation of Anthropic-style tool_result."""
long_content = "y" * 10000
msg = {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_123",
"content": long_content,
}
],
}
_truncate_tool_message_content(msg, enc, max_tokens=100)
# Content should be truncated
result_content = msg["content"][0]["content"]
assert len(result_content) < len(long_content)
assert "" in result_content
def test_preserve_tool_use_blocks(self, enc):
"""Test that tool_use blocks are not truncated."""
msg = {
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "toolu_123",
"name": "some_function",
"input": {"key": "value" * 1000}, # Large input
}
],
}
original = json.dumps(msg["content"][0]["input"])
_truncate_tool_message_content(msg, enc, max_tokens=10)
# tool_use should be unchanged
assert json.dumps(msg["content"][0]["input"]) == original
def test_no_truncation_when_under_limit(self, enc):
"""Test that short content is not modified."""
msg = {"role": "tool", "tool_call_id": "call_123", "content": "Short content"}
original = msg["content"]
_truncate_tool_message_content(msg, enc, max_tokens=1000)
assert msg["content"] == original
class TestTruncateMiddleTokens:
"""Test middle truncation of text."""
@pytest.fixture
def enc(self):
return encoding_for_model("gpt-4o")
def test_truncates_long_text(self, enc):
"""Test that long text is truncated with ellipsis in middle."""
long_text = "word " * 1000
result = _truncate_middle_tokens(long_text, enc, max_tok=50)
assert len(enc.encode(result)) <= 52 # Allow some slack for ellipsis
assert "" in result
assert result.startswith("word") # Head preserved
assert result.endswith("word ") # Tail preserved
def test_preserves_short_text(self, enc):
"""Test that short text is not modified."""
short_text = "Hello world"
result = _truncate_middle_tokens(short_text, enc, max_tok=100)
assert result == short_text
class TestEnsureToolPairsIntact:
"""Test tool call/response pair preservation for both OpenAI and Anthropic formats."""
# ---- OpenAI Format Tests ----
def test_openai_adds_missing_tool_call(self):
"""Test that orphaned OpenAI tool_response gets its tool_call prepended."""
all_msgs = [
{"role": "system", "content": "You are helpful."},
{
"role": "assistant",
"tool_calls": [
{"id": "call_1", "type": "function", "function": {"name": "f1"}}
],
},
{"role": "tool", "tool_call_id": "call_1", "content": "result"},
{"role": "user", "content": "Thanks!"},
]
# Recent messages start at index 2 (the tool response)
recent = [all_msgs[2], all_msgs[3]]
start_index = 2
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
# Should prepend the tool_call message
assert len(result) == 3
assert result[0]["role"] == "assistant"
assert "tool_calls" in result[0]
def test_openai_keeps_complete_pairs(self):
"""Test that complete OpenAI pairs are unchanged."""
all_msgs = [
{"role": "system", "content": "System"},
{
"role": "assistant",
"tool_calls": [
{"id": "call_1", "type": "function", "function": {"name": "f1"}}
],
},
{"role": "tool", "tool_call_id": "call_1", "content": "result"},
]
recent = all_msgs[1:] # Include both tool_call and response
start_index = 1
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
assert len(result) == 2 # No messages added
def test_openai_multiple_tool_calls(self):
"""Test multiple OpenAI tool calls in one assistant message."""
all_msgs = [
{"role": "system", "content": "System"},
{
"role": "assistant",
"tool_calls": [
{"id": "call_1", "type": "function", "function": {"name": "f1"}},
{"id": "call_2", "type": "function", "function": {"name": "f2"}},
],
},
{"role": "tool", "tool_call_id": "call_1", "content": "result1"},
{"role": "tool", "tool_call_id": "call_2", "content": "result2"},
{"role": "user", "content": "Thanks!"},
]
# Recent messages start at index 2 (first tool response)
recent = [all_msgs[2], all_msgs[3], all_msgs[4]]
start_index = 2
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
# Should prepend the assistant message with both tool_calls
assert len(result) == 4
assert result[0]["role"] == "assistant"
assert len(result[0]["tool_calls"]) == 2
# ---- Anthropic Format Tests ----
def test_anthropic_adds_missing_tool_use(self):
"""Test that orphaned Anthropic tool_result gets its tool_use prepended."""
all_msgs = [
{"role": "system", "content": "You are helpful."},
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "toolu_123",
"name": "get_weather",
"input": {"location": "SF"},
}
],
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_123",
"content": "22°C and sunny",
}
],
},
{"role": "user", "content": "Thanks!"},
]
# Recent messages start at index 2 (the tool_result)
recent = [all_msgs[2], all_msgs[3]]
start_index = 2
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
# Should prepend the tool_use message
assert len(result) == 3
assert result[0]["role"] == "assistant"
assert result[0]["content"][0]["type"] == "tool_use"
def test_anthropic_keeps_complete_pairs(self):
"""Test that complete Anthropic pairs are unchanged."""
all_msgs = [
{"role": "system", "content": "System"},
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "toolu_456",
"name": "calculator",
"input": {"expr": "2+2"},
}
],
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_456",
"content": "4",
}
],
},
]
recent = all_msgs[1:] # Include both tool_use and result
start_index = 1
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
assert len(result) == 2 # No messages added
def test_anthropic_multiple_tool_uses(self):
"""Test multiple Anthropic tool_use blocks in one message."""
all_msgs = [
{"role": "system", "content": "System"},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Let me check both..."},
{
"type": "tool_use",
"id": "toolu_1",
"name": "get_weather",
"input": {"city": "NYC"},
},
{
"type": "tool_use",
"id": "toolu_2",
"name": "get_weather",
"input": {"city": "LA"},
},
],
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_1",
"content": "Cold",
},
{
"type": "tool_result",
"tool_use_id": "toolu_2",
"content": "Warm",
},
],
},
{"role": "user", "content": "Thanks!"},
]
# Recent messages start at index 2 (tool_result)
recent = [all_msgs[2], all_msgs[3]]
start_index = 2
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
# Should prepend the assistant message with both tool_uses
assert len(result) == 3
assert result[0]["role"] == "assistant"
tool_use_count = sum(
1 for b in result[0]["content"] if b.get("type") == "tool_use"
)
assert tool_use_count == 2
# ---- Mixed/Edge Case Tests ----
def test_anthropic_with_type_message_field(self):
"""Test Anthropic format with 'type': 'message' field (smart_decision_maker style)."""
all_msgs = [
{"role": "system", "content": "You are helpful."},
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "toolu_abc",
"name": "search",
"input": {"q": "test"},
}
],
},
{
"role": "user",
"type": "message", # Extra field from smart_decision_maker
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_abc",
"content": "Found results",
}
],
},
{"role": "user", "content": "Thanks!"},
]
# Recent messages start at index 2 (the tool_result with 'type': 'message')
recent = [all_msgs[2], all_msgs[3]]
start_index = 2
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
# Should prepend the tool_use message
assert len(result) == 3
assert result[0]["role"] == "assistant"
assert result[0]["content"][0]["type"] == "tool_use"
def test_handles_no_tool_messages(self):
"""Test messages without tool calls."""
all_msgs = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
]
recent = all_msgs
start_index = 0
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
assert result == all_msgs
def test_handles_empty_messages(self):
"""Test empty message list."""
result = _ensure_tool_pairs_intact([], [], 0)
assert result == []
def test_mixed_text_and_tool_content(self):
"""Test Anthropic message with mixed text and tool_use content."""
all_msgs = [
{
"role": "assistant",
"content": [
{"type": "text", "text": "I'll help you with that."},
{
"type": "tool_use",
"id": "toolu_mixed",
"name": "search",
"input": {"q": "test"},
},
],
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_mixed",
"content": "Found results",
}
],
},
{"role": "assistant", "content": "Here are the results..."},
]
# Start from tool_result
recent = [all_msgs[1], all_msgs[2]]
start_index = 1
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
# Should prepend the assistant message with tool_use
assert len(result) == 3
assert result[0]["content"][0]["type"] == "text"
assert result[0]["content"][1]["type"] == "tool_use"
class TestCompressContext:
"""Test the async compress_context function."""
@pytest.mark.asyncio
async def test_no_compression_needed(self):
"""Test messages under limit return without compression."""
messages = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hello!"},
]
result = await compress_context(messages, target_tokens=100000)
assert isinstance(result, CompressResult)
assert result.was_compacted is False
assert len(result.messages) == 2
assert result.error is None
@pytest.mark.asyncio
async def test_truncation_without_client(self):
"""Test that truncation works without LLM client."""
long_content = "x" * 50000
messages = [
{"role": "system", "content": "System"},
{"role": "user", "content": long_content},
{"role": "assistant", "content": "Response"},
]
result = await compress_context(
messages, target_tokens=1000, client=None, reserve=100
)
assert result.was_compacted is True
# Should have truncated without summarization
assert result.messages_summarized == 0
@pytest.mark.asyncio
async def test_with_mocked_llm_client(self):
"""Test summarization with mocked LLM client."""
# Create many messages to trigger summarization
messages = [{"role": "system", "content": "System prompt"}]
for i in range(30):
messages.append({"role": "user", "content": f"User message {i} " * 100})
messages.append(
{"role": "assistant", "content": f"Assistant response {i} " * 100}
)
# Mock the AsyncOpenAI client
mock_client = AsyncMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Summary of conversation"
mock_client.with_options.return_value.chat.completions.create = AsyncMock(
return_value=mock_response
)
result = await compress_context(
messages,
target_tokens=5000,
client=mock_client,
keep_recent=5,
reserve=500,
)
assert result.was_compacted is True
# Should have attempted summarization
assert mock_client.with_options.called or result.messages_summarized > 0
@pytest.mark.asyncio
async def test_preserves_tool_pairs(self):
"""Test that tool call/response pairs stay together."""
messages = [
{"role": "system", "content": "System"},
{"role": "user", "content": "Do something"},
{
"role": "assistant",
"tool_calls": [
{"id": "call_1", "type": "function", "function": {"name": "func"}}
],
},
{"role": "tool", "tool_call_id": "call_1", "content": "Result " * 1000},
{"role": "assistant", "content": "Done!"},
]
result = await compress_context(
messages, target_tokens=500, client=None, reserve=50
)
# Check that if tool response exists, its call exists too
tool_call_ids = set()
tool_response_ids = set()
for msg in result.messages:
if "tool_calls" in msg:
for tc in msg["tool_calls"]:
tool_call_ids.add(tc["id"])
if msg.get("role") == "tool":
tool_response_ids.add(msg.get("tool_call_id"))
# All tool responses should have their calls
assert tool_response_ids <= tool_call_ids
@pytest.mark.asyncio
async def test_returns_error_when_cannot_compress(self):
"""Test that error is returned when compression fails."""
# Single huge message that can't be compressed enough
messages = [
{"role": "user", "content": "x" * 100000},
]
result = await compress_context(
messages, target_tokens=100, client=None, reserve=50
)
# Should have an error since we can't get below 100 tokens
assert result.error is not None
assert result.was_compacted is True
@pytest.mark.asyncio
async def test_empty_messages(self):
"""Test that empty messages list returns early without error."""
result = await compress_context([], target_tokens=1000)
assert result.messages == []
assert result.token_count == 0
assert result.was_compacted is False
assert result.error is None
class TestRemoveOrphanToolResponses:
"""Test _remove_orphan_tool_responses helper function."""
def test_removes_openai_orphan(self):
"""Test removal of orphan OpenAI tool response."""
from backend.util.prompt import _remove_orphan_tool_responses
messages = [
{"role": "tool", "tool_call_id": "call_orphan", "content": "result"},
{"role": "user", "content": "Hello"},
]
orphan_ids = {"call_orphan"}
result = _remove_orphan_tool_responses(messages, orphan_ids)
assert len(result) == 1
assert result[0]["role"] == "user"
def test_keeps_valid_openai_tool(self):
"""Test that valid OpenAI tool responses are kept."""
from backend.util.prompt import _remove_orphan_tool_responses
messages = [
{"role": "tool", "tool_call_id": "call_valid", "content": "result"},
]
orphan_ids = {"call_other"}
result = _remove_orphan_tool_responses(messages, orphan_ids)
assert len(result) == 1
assert result[0]["tool_call_id"] == "call_valid"
def test_filters_anthropic_mixed_blocks(self):
"""Test filtering individual orphan blocks from Anthropic message with mixed valid/orphan."""
from backend.util.prompt import _remove_orphan_tool_responses
messages = [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_valid",
"content": "valid result",
},
{
"type": "tool_result",
"tool_use_id": "toolu_orphan",
"content": "orphan result",
},
],
},
]
orphan_ids = {"toolu_orphan"}
result = _remove_orphan_tool_responses(messages, orphan_ids)
assert len(result) == 1
# Should only have the valid tool_result, orphan filtered out
assert len(result[0]["content"]) == 1
assert result[0]["content"][0]["tool_use_id"] == "toolu_valid"
def test_removes_anthropic_all_orphan(self):
"""Test removal of Anthropic message when all tool_results are orphans."""
from backend.util.prompt import _remove_orphan_tool_responses
messages = [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_orphan1",
"content": "result1",
},
{
"type": "tool_result",
"tool_use_id": "toolu_orphan2",
"content": "result2",
},
],
},
]
orphan_ids = {"toolu_orphan1", "toolu_orphan2"}
result = _remove_orphan_tool_responses(messages, orphan_ids)
# Message should be completely removed since no content left
assert len(result) == 0
def test_preserves_non_tool_messages(self):
"""Test that non-tool messages are preserved."""
from backend.util.prompt import _remove_orphan_tool_responses
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
]
orphan_ids = {"some_id"}
result = _remove_orphan_tool_responses(messages, orphan_ids)
assert result == messages
class TestCompressResultDataclass:
"""Test CompressResult dataclass."""
def test_default_values(self):
"""Test default values are set correctly."""
result = CompressResult(
messages=[{"role": "user", "content": "test"}],
token_count=10,
was_compacted=False,
)
assert result.error is None
assert result.original_token_count == 0 # Defaults to 0, not None
assert result.messages_summarized == 0
assert result.messages_dropped == 0
def test_all_fields(self):
"""Test all fields can be set."""
result = CompressResult(
messages=[{"role": "user", "content": "test"}],
token_count=100,
was_compacted=True,
error="Some error",
original_token_count=500,
messages_summarized=10,
messages_dropped=5,
)
assert result.token_count == 100
assert result.was_compacted is True
assert result.error == "Some error"
assert result.original_token_count == 500
assert result.messages_summarized == 10
assert result.messages_dropped == 5

View File

@@ -111,7 +111,9 @@ class TestGenerateAgent:
instructions = {"type": "instructions", "steps": ["Step 1"]} instructions = {"type": "instructions", "steps": ["Step 1"]}
result = await core.generate_agent(instructions) result = await core.generate_agent(instructions)
mock_external.assert_called_once_with(instructions, None, None, None) # library_agents defaults to None
mock_external.assert_called_once_with(instructions, None)
# Result should have id, version, is_active added if not present
assert result is not None assert result is not None
assert result["name"] == "Test Agent" assert result["name"] == "Test Agent"
assert "id" in result assert "id" in result
@@ -175,9 +177,8 @@ class TestGenerateAgentPatch:
current_agent = {"nodes": [], "links": []} current_agent = {"nodes": [], "links": []}
result = await core.generate_agent_patch("Add a node", current_agent) result = await core.generate_agent_patch("Add a node", current_agent)
mock_external.assert_called_once_with( # library_agents defaults to None
"Add a node", current_agent, None, None, None mock_external.assert_called_once_with("Add a node", current_agent, None)
)
assert result == expected_result assert result == expected_result
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@@ -102,7 +102,7 @@ class TestDecomposeGoalExternal:
     @pytest.mark.asyncio
     async def test_decompose_goal_with_context(self):
-        """Test decomposition with additional context enriched into description."""
+        """Test decomposition with additional context."""
         mock_response = MagicMock()
         mock_response.json.return_value = {
             "success": True,
@@ -119,12 +119,9 @@ class TestDecomposeGoalExternal:
             "Build a chatbot", context="Use Python"
         )

-        expected_description = (
-            "Build a chatbot\n\nAdditional context from user:\nUse Python"
-        )
         mock_client.post.assert_called_once_with(
             "/api/decompose-description",
-            json={"description": expected_description},
+            json={"description": "Build a chatbot", "user_instruction": "Use Python"},
         )

     @pytest.mark.asyncio

View File

@@ -1,9 +1,10 @@
"use client"; "use client";
import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
import { getOnboardingStatus, resolveResponse } from "@/app/api/helpers";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner"; import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import { useEffect } from "react"; import { useEffect } from "react";
import { resolveResponse, getOnboardingStatus } from "@/app/api/helpers";
import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
import { getHomepageRoute } from "@/lib/constants";
export default function OnboardingPage() { export default function OnboardingPage() {
const router = useRouter(); const router = useRouter();
@@ -12,10 +13,12 @@ export default function OnboardingPage() {
async function redirectToStep() { async function redirectToStep() {
try { try {
// Check if onboarding is enabled (also gets chat flag for redirect) // Check if onboarding is enabled (also gets chat flag for redirect)
const { shouldShowOnboarding } = await getOnboardingStatus(); const { shouldShowOnboarding, isChatEnabled } =
await getOnboardingStatus();
const homepageRoute = getHomepageRoute(isChatEnabled);
if (!shouldShowOnboarding) { if (!shouldShowOnboarding) {
router.replace("/"); router.replace(homepageRoute);
return; return;
} }
@@ -23,7 +26,7 @@ export default function OnboardingPage() {
// Handle completed onboarding // Handle completed onboarding
if (onboarding.completedSteps.includes("GET_RESULTS")) { if (onboarding.completedSteps.includes("GET_RESULTS")) {
router.replace("/"); router.replace(homepageRoute);
return; return;
} }

View File

@@ -1,8 +1,9 @@
+import { getOnboardingStatus } from "@/app/api/helpers";
+import BackendAPI from "@/lib/autogpt-server-api";
 import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
-import { revalidatePath } from "next/cache";
+import { getHomepageRoute } from "@/lib/constants";
-import BackendAPI from "@/lib/autogpt-server-api";
 import { NextResponse } from "next/server";
+import { revalidatePath } from "next/cache";
-import { getOnboardingStatus } from "@/app/api/helpers";

 // Handle the callback to complete the user session login
 export async function GET(request: Request) {
@@ -26,12 +27,13 @@ export async function GET(request: Request) {
         await api.createUser();

         // Get onboarding status from backend (includes chat flag evaluated for this user)
-        const { shouldShowOnboarding } = await getOnboardingStatus();
+        const { shouldShowOnboarding, isChatEnabled } =
+          await getOnboardingStatus();

         if (shouldShowOnboarding) {
           next = "/onboarding";
           revalidatePath("/onboarding", "layout");
         } else {
-          next = "/";
+          next = getHomepageRoute(isChatEnabled);
           revalidatePath(next, "layout");
         }
       } catch (createUserError) {

View File

@@ -11,6 +11,7 @@ import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
 import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
 import { useQueryClient } from "@tanstack/react-query";
 import { usePathname, useSearchParams } from "next/navigation";
+import { useRef } from "react";
 import { useCopilotStore } from "../../copilot-page-store";
 import { useCopilotSessionId } from "../../useCopilotSessionId";
 import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
@@ -69,16 +70,41 @@ export function useCopilotShell() {
   });

   const stopStream = useChatStore((s) => s.stopStream);
+  const onStreamComplete = useChatStore((s) => s.onStreamComplete);
+  const isStreaming = useCopilotStore((s) => s.isStreaming);
   const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);
+  const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession);
+  const openInterruptModal = useCopilotStore((s) => s.openInterruptModal);

-  function handleSessionClick(sessionId: string) {
-    if (sessionId === currentSessionId) return;
-    // Stop current stream - SSE reconnection allows resuming later
-    if (currentSessionId) {
-      stopStream(currentSessionId);
-    }
+  const pendingActionRef = useRef<(() => void) | null>(null);
+
+  async function stopCurrentStream() {
+    if (!currentSessionId) return;
+    setIsSwitchingSession(true);
+    await new Promise<void>((resolve) => {
+      const unsubscribe = onStreamComplete((completedId) => {
+        if (completedId === currentSessionId) {
+          clearTimeout(timeout);
+          unsubscribe();
+          resolve();
+        }
+      });
+      const timeout = setTimeout(() => {
+        unsubscribe();
+        resolve();
+      }, 3000);
+      stopStream(currentSessionId);
+    });
+    queryClient.invalidateQueries({
+      queryKey: getGetV2GetSessionQueryKey(currentSessionId),
+    });
+    setIsSwitchingSession(false);
+  }
+
+  function selectSession(sessionId: string) {
+    if (sessionId === currentSessionId) return;
     if (recentlyCreatedSessionsRef.current.has(sessionId)) {
       queryClient.invalidateQueries({
         queryKey: getGetV2GetSessionQueryKey(sessionId),
@@ -88,12 +114,7 @@ export function useCopilotShell() {
     if (isMobile) handleCloseDrawer();
   }

-  function handleNewChatClick() {
-    // Stop current stream - SSE reconnection allows resuming later
-    if (currentSessionId) {
-      stopStream(currentSessionId);
-    }
+  function startNewChat() {
     resetPagination();
     queryClient.invalidateQueries({
       queryKey: getGetV2ListSessionsQueryKey(),
@@ -102,6 +123,32 @@ export function useCopilotShell() {
     if (isMobile) handleCloseDrawer();
   }

+  function handleSessionClick(sessionId: string) {
+    if (sessionId === currentSessionId) return;
+    if (isStreaming) {
+      pendingActionRef.current = async () => {
+        await stopCurrentStream();
+        selectSession(sessionId);
+      };
+      openInterruptModal(pendingActionRef.current);
+    } else {
+      selectSession(sessionId);
+    }
+  }
+
+  function handleNewChatClick() {
+    if (isStreaming) {
+      pendingActionRef.current = async () => {
+        await stopCurrentStream();
+        startNewChat();
+      };
+      openInterruptModal(pendingActionRef.current);
+    } else {
+      startNewChat();
+    }
+  }
+
   return {
     isMobile,
     isDrawerOpen,

View File

@@ -1,13 +1,6 @@
"use client"; import type { ReactNode } from "react";
import { FeatureFlagPage } from "@/services/feature-flags/FeatureFlagPage";
import { Flag } from "@/services/feature-flags/use-get-flag";
import { type ReactNode } from "react";
import { CopilotShell } from "./components/CopilotShell/CopilotShell"; import { CopilotShell } from "./components/CopilotShell/CopilotShell";
export default function CopilotLayout({ children }: { children: ReactNode }) { export default function CopilotLayout({ children }: { children: ReactNode }) {
return ( return <CopilotShell>{children}</CopilotShell>;
<FeatureFlagPage flag={Flag.CHAT} whenDisabled="/library">
<CopilotShell>{children}</CopilotShell>
</FeatureFlagPage>
);
} }

View File

@@ -14,8 +14,14 @@ export default function CopilotPage() {
   const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
   const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
   const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
-  const { greetingName, quickActions, isLoading, hasSession, initialPrompt } =
-    state;
+  const {
+    greetingName,
+    quickActions,
+    isLoading,
+    hasSession,
+    initialPrompt,
+    isReady,
+  } = state;
   const {
     handleQuickAction,
     startChatWithPrompt,
@@ -23,6 +29,8 @@ export default function CopilotPage() {
     handleStreamingChange,
   } = handlers;

+  if (!isReady) return null;
+
   if (hasSession) {
     return (
       <div className="flex h-full flex-col">

View File

@@ -3,11 +3,18 @@ import {
postV2CreateSession,
} from "@/app/api/__generated__/endpoints/chat/chat";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { getHomepageRoute } from "@/lib/constants";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
import {
Flag,
type FlagValues,
useGetFlag,
} from "@/services/feature-flags/use-get-flag";
import { SessionKey, sessionStorage } from "@/services/storage/session-storage";
import * as Sentry from "@sentry/nextjs";
import { useQueryClient } from "@tanstack/react-query";
import { useFlags } from "launchdarkly-react-client-sdk";
import { useRouter } from "next/navigation";
import { useEffect } from "react";
import { useCopilotStore } from "./copilot-page-store";
@@ -26,6 +33,22 @@ export function useCopilotPage() {
const isCreating = useCopilotStore((s) => s.isCreatingSession);
const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession);
// Complete VISIT_COPILOT onboarding step to grant $5 welcome bonus
useEffect(() => {
if (isLoggedIn) {
completeStep("VISIT_COPILOT");
}
}, [completeStep, isLoggedIn]);
const isChatEnabled = useGetFlag(Flag.CHAT);
const flags = useFlags<FlagValues>();
const homepageRoute = getHomepageRoute(isChatEnabled);
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
const isFlagReady =
!isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined;
const greetingName = getGreetingName(user);
const quickActions = getQuickActions();
@@ -35,8 +58,11 @@ export function useCopilotPage() {
: undefined;
useEffect(() => {
-if (isLoggedIn) completeStep("VISIT_COPILOT");
-}, [completeStep, isLoggedIn]);
+if (!isFlagReady) return;
+if (isChatEnabled === false) {
+router.replace(homepageRoute);
+}
+}, [homepageRoute, isChatEnabled, isFlagReady, router]);
async function startChatWithPrompt(prompt: string) {
if (!prompt?.trim()) return;
@@ -90,6 +116,7 @@ export function useCopilotPage() {
isLoading: isUserLoading,
hasSession,
initialPrompt,
isReady: isFlagReady && isChatEnabled !== false && isLoggedIn,
},
handlers: {
handleQuickAction,

View File

@@ -1,6 +1,8 @@
"use client"; "use client";
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard"; import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { getHomepageRoute } from "@/lib/constants";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { useSearchParams } from "next/navigation";
import { Suspense } from "react";
import { getErrorDetails } from "./helpers";
@@ -9,6 +11,8 @@ function ErrorPageContent() {
const searchParams = useSearchParams();
const errorMessage = searchParams.get("message");
const errorDetails = getErrorDetails(errorMessage);
const isChatEnabled = useGetFlag(Flag.CHAT);
const homepageRoute = getHomepageRoute(isChatEnabled);
function handleRetry() {
// Auth-related errors should redirect to login
@@ -26,7 +30,7 @@ function ErrorPageContent() {
}, 2000);
} else {
// For server/network errors, go to home
-window.location.href = "/";
+window.location.href = homepageRoute;
}
}

View File

@@ -1,5 +1,6 @@
"use server"; "use server";
import { getHomepageRoute } from "@/lib/constants";
import BackendAPI from "@/lib/autogpt-server-api";
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
import { loginFormSchema } from "@/types/auth";
@@ -37,8 +38,10 @@ export async function login(email: string, password: string) {
await api.createUser();
// Get onboarding status from backend (includes chat flag evaluated for this user)
-const { shouldShowOnboarding } = await getOnboardingStatus();
-const next = shouldShowOnboarding ? "/onboarding" : "/";
+const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus();
+const next = shouldShowOnboarding
+? "/onboarding"
+: getHomepageRoute(isChatEnabled);
return {
success: true,

View File

@@ -1,6 +1,8 @@
import { useToast } from "@/components/molecules/Toast/use-toast";
import { getHomepageRoute } from "@/lib/constants";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { environment } from "@/services/environment";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { loginFormSchema, LoginProvider } from "@/types/auth";
import { zodResolver } from "@hookform/resolvers/zod";
import { useRouter, useSearchParams } from "next/navigation";
@@ -20,15 +22,17 @@ export function useLoginPage() {
const [isGoogleLoading, setIsGoogleLoading] = useState(false);
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
const isCloudEnv = environment.isCloud();
const isChatEnabled = useGetFlag(Flag.CHAT);
const homepageRoute = getHomepageRoute(isChatEnabled);
// Get redirect destination from 'next' query parameter
const nextUrl = searchParams.get("next");
useEffect(() => {
if (isLoggedIn && !isLoggingIn) {
-router.push(nextUrl || "/");
+router.push(nextUrl || homepageRoute);
}
-}, [isLoggedIn, isLoggingIn, nextUrl, router]);
+}, [homepageRoute, isLoggedIn, isLoggingIn, nextUrl, router]);
const form = useForm<z.infer<typeof loginFormSchema>>({
resolver: zodResolver(loginFormSchema),
@@ -94,7 +98,7 @@ export function useLoginPage() {
}
// Prefer URL's next parameter, then use backend-determined route
-router.replace(nextUrl || result.next || "/");
+router.replace(nextUrl || result.next || homepageRoute);
} catch (error) {
toast({
title:

View File

@@ -1,5 +1,6 @@
"use server"; "use server";
import { getHomepageRoute } from "@/lib/constants";
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
import { signupFormSchema } from "@/types/auth";
import * as Sentry from "@sentry/nextjs";
@@ -58,8 +59,10 @@ export async function signup(
}
// Get onboarding status from backend (includes chat flag evaluated for this user)
-const { shouldShowOnboarding } = await getOnboardingStatus();
-const next = shouldShowOnboarding ? "/onboarding" : "/";
+const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus();
+const next = shouldShowOnboarding
+? "/onboarding"
+: getHomepageRoute(isChatEnabled);
return { success: true, next };
} catch (err) {

View File

@@ -1,6 +1,8 @@
import { useToast } from "@/components/molecules/Toast/use-toast";
import { getHomepageRoute } from "@/lib/constants";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { environment } from "@/services/environment";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { LoginProvider, signupFormSchema } from "@/types/auth";
import { zodResolver } from "@hookform/resolvers/zod";
import { useRouter, useSearchParams } from "next/navigation";
@@ -20,15 +22,17 @@ export function useSignupPage() {
const [isGoogleLoading, setIsGoogleLoading] = useState(false);
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
const isCloudEnv = environment.isCloud();
const isChatEnabled = useGetFlag(Flag.CHAT);
const homepageRoute = getHomepageRoute(isChatEnabled);
// Get redirect destination from 'next' query parameter
const nextUrl = searchParams.get("next");
useEffect(() => {
if (isLoggedIn && !isSigningUp) {
-router.push(nextUrl || "/");
+router.push(nextUrl || homepageRoute);
}
-}, [isLoggedIn, isSigningUp, nextUrl, router]);
+}, [homepageRoute, isLoggedIn, isSigningUp, nextUrl, router]);
const form = useForm<z.infer<typeof signupFormSchema>>({
resolver: zodResolver(signupFormSchema),
@@ -129,7 +133,7 @@ export function useSignupPage() {
}
// Prefer the URL's next parameter, then result.next (for onboarding), then default
-const redirectTo = nextUrl || result.next || "/";
+const redirectTo = nextUrl || result.next || homepageRoute;
router.replace(redirectTo);
} catch (error) {
setIsLoading(false);

View File

@@ -1,81 +0,0 @@
import { environment } from "@/services/environment";
import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
import { NextRequest } from "next/server";
/**
* SSE Proxy for task stream reconnection.
*
* This endpoint allows clients to reconnect to an ongoing or recently completed
* background task's stream. It replays missed messages from Redis Streams and
* subscribes to live updates if the task is still running.
*
* Client contract:
* 1. When receiving an operation_started event, store the task_id
* 2. To reconnect: GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
* 3. Messages are replayed from the last_message_id position
* 4. Stream ends when "finish" event is received
*/
export async function GET(
request: NextRequest,
{ params }: { params: Promise<{ taskId: string }> },
) {
const { taskId } = await params;
const searchParams = request.nextUrl.searchParams;
const lastMessageId = searchParams.get("last_message_id") || "0-0";
try {
// Get auth token from server-side session
const token = await getServerAuthToken();
// Build backend URL
const backendUrl = environment.getAGPTServerBaseUrl();
const streamUrl = new URL(`/api/chat/tasks/${taskId}/stream`, backendUrl);
streamUrl.searchParams.set("last_message_id", lastMessageId);
// Forward request to backend with auth header
const headers: Record<string, string> = {
Accept: "text/event-stream",
"Cache-Control": "no-cache",
Connection: "keep-alive",
};
if (token) {
headers["Authorization"] = `Bearer ${token}`;
}
const response = await fetch(streamUrl.toString(), {
method: "GET",
headers,
});
if (!response.ok) {
const error = await response.text();
return new Response(error, {
status: response.status,
headers: { "Content-Type": "application/json" },
});
}
// Return the SSE stream directly
return new Response(response.body, {
headers: {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache, no-transform",
Connection: "keep-alive",
"X-Accel-Buffering": "no",
},
});
} catch (error) {
console.error("Task stream proxy error:", error);
return new Response(
JSON.stringify({
error: "Failed to connect to task stream",
detail: error instanceof Error ? error.message : String(error),
}),
{
status: 500,
headers: { "Content-Type": "application/json" },
},
);
}
}
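For reference, a minimal sketch of how a client could have consumed this proxy route while it existed, following the contract in the header comment above. The SSE parsing below is simplified and the route path is the one shown in this file; treat it as an illustration, not the shipped client code.

```typescript
// Sketch: replaying a task stream through the proxy route above.
async function reconnectToTaskStream(
  taskId: string,
  lastMessageId = "0-0",
  onEvent?: (data: string) => void,
): Promise<void> {
  const url = `/api/chat/tasks/${taskId}/stream?last_message_id=${encodeURIComponent(lastMessageId)}`;
  const response = await fetch(url, {
    headers: { Accept: "text/event-stream" },
  });
  if (!response.ok || !response.body) {
    throw new Error(`Stream reconnect failed: ${response.status}`);
  }
  const reader = response.body.getReader();
  const decoder = new TextDecoder();
  let buffer = "";
  for (;;) {
    const { done, value } = await reader.read();
    if (done) break;
    buffer += decoder.decode(value, { stream: true });
    // SSE events are separated by a blank line; payload lines start with "data: ".
    const events = buffer.split("\n\n");
    buffer = events.pop() ?? "";
    for (const event of events) {
      for (const line of event.split("\n")) {
        if (line.startsWith("data: ")) onEvent?.(line.slice(6));
      }
    }
  }
}
```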

View File

@@ -181,5 +181,6 @@ export async function getOnboardingStatus() {
const isCompleted = onboarding.completedSteps.includes("CONGRATS");
return {
shouldShowOnboarding: status.is_onboarding_enabled && !isCompleted,
isChatEnabled: status.is_chat_enabled,
};
}
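The implementation of `getHomepageRoute` is never shown in this diff. Judging from its call sites (and the old layout's `whenDisabled="/library"` fallback), a purely hypothetical sketch might look like:

```typescript
// Hypothetical sketch only - the real helper lives in "@/lib/constants" and
// is not part of this diff. Call sites suggest it maps the chat flag to a route.
export function getHomepageRoute(isChatEnabled: unknown): string {
  return isChatEnabled === true ? "/copilot" : "/library";
}
```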

View File

@@ -917,28 +917,6 @@
"security": [{ "HTTPBearerJWT": [] }] "security": [{ "HTTPBearerJWT": [] }]
} }
}, },
"/api/chat/config/ttl": {
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Get Ttl Config",
"description": "Get the stream TTL configuration.\n\nReturns the Time-To-Live settings for chat streams, which determines\nhow long clients can reconnect to an active stream.\n\nReturns:\n dict: TTL configuration with seconds and milliseconds values.",
"operationId": "getV2GetTtlConfig",
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"additionalProperties": true,
"type": "object",
"title": "Response Getv2Getttlconfig"
}
}
}
}
}
}
},
"/api/chat/health": { "/api/chat/health": {
"get": { "get": {
"tags": ["v2", "chat", "chat"], "tags": ["v2", "chat", "chat"],
@@ -961,63 +939,6 @@
}
}
},
"/api/chat/operations/{operation_id}/complete": {
"post": {
"tags": ["v2", "chat", "chat"],
"summary": "Complete Operation",
"description": "External completion webhook for long-running operations.\n\nCalled by Agent Generator (or other services) when an operation completes.\nThis triggers the stream registry to publish completion and continue LLM generation.\n\nArgs:\n operation_id: The operation ID to complete.\n request: Completion payload with success status and result/error.\n x_api_key: Internal API key for authentication.\n\nReturns:\n dict: Status of the completion.\n\nRaises:\n HTTPException: If API key is invalid or operation not found.",
"operationId": "postV2CompleteOperation",
"parameters": [
{
"name": "operation_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Operation Id" }
},
{
"name": "x-api-key",
"in": "header",
"required": false,
"schema": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "X-Api-Key"
}
}
],
"requestBody": {
"required": true,
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/OperationCompleteRequest"
}
}
}
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"type": "object",
"additionalProperties": true,
"title": "Response Postv2Completeoperation"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/chat/sessions": { "/api/chat/sessions": {
"get": { "get": {
"tags": ["v2", "chat", "chat"], "tags": ["v2", "chat", "chat"],
@@ -1101,7 +1022,7 @@
"get": { "get": {
"tags": ["v2", "chat", "chat"], "tags": ["v2", "chat", "chat"],
"summary": "Get Session", "summary": "Get Session",
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\nIf there's an active stream for this session, returns the task_id for reconnection.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, including active_stream info if applicable.", "description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, or None if not found.",
"operationId": "getV2GetSession", "operationId": "getV2GetSession",
"security": [{ "HTTPBearerJWT": [] }], "security": [{ "HTTPBearerJWT": [] }],
"parameters": [ "parameters": [
@@ -1236,7 +1157,7 @@
"post": { "post": {
"tags": ["v2", "chat", "chat"], "tags": ["v2", "chat", "chat"],
"summary": "Stream Chat Post", "summary": "Stream Chat Post",
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to Redis for reconnection support. If the client disconnects,\nthey can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks. First chunk is a \"start\" event\n containing the task_id for reconnection.", "description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
"operationId": "postV2StreamChatPost", "operationId": "postV2StreamChatPost",
"security": [{ "HTTPBearerJWT": [] }], "security": [{ "HTTPBearerJWT": [] }],
"parameters": [ "parameters": [
@@ -1274,94 +1195,6 @@
}
}
},
"/api/chat/tasks/{task_id}": {
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Get Task Status",
"description": "Get the status of a long-running task.\n\nArgs:\n task_id: The task ID to check.\n user_id: Authenticated user ID for ownership validation.\n\nReturns:\n dict: Task status including task_id, status, tool_name, and operation_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.",
"operationId": "getV2GetTaskStatus",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "task_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Task Id" }
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"type": "object",
"additionalProperties": true,
"title": "Response Getv2Gettaskstatus"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/chat/tasks/{task_id}/stream": {
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Stream Task",
"description": "Reconnect to a long-running task's SSE stream.\n\nWhen a long-running operation (like agent generation) starts, the client\nreceives a task_id. If the connection drops, the client can reconnect\nusing this endpoint to resume receiving updates.\n\nArgs:\n task_id: The task ID from the operation_started response.\n user_id: Authenticated user ID for ownership validation.\n last_message_id: Last Redis Stream message ID received (\"0-0\" for full replay).\n\nReturns:\n StreamingResponse: SSE-formatted response chunks starting after last_message_id.\n\nRaises:\n HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.",
"operationId": "getV2StreamTask",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "task_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Task Id" }
},
{
"name": "last_message_id",
"in": "query",
"required": false,
"schema": {
"type": "string",
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
"default": "0-0",
"title": "Last Message Id"
},
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay."
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": { "application/json": { "schema": {} } }
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
}
},
"/api/credits": { "/api/credits": {
"get": { "get": {
"tags": ["v1", "credits"], "tags": ["v1", "credits"],
@@ -6335,18 +6168,6 @@
"title": "AccuracyTrendsResponse", "title": "AccuracyTrendsResponse",
"description": "Response model for accuracy trends and alerts." "description": "Response model for accuracy trends and alerts."
}, },
"ActiveStreamInfo": {
"properties": {
"task_id": { "type": "string", "title": "Task Id" },
"last_message_id": { "type": "string", "title": "Last Message Id" },
"operation_id": { "type": "string", "title": "Operation Id" },
"tool_name": { "type": "string", "title": "Tool Name" }
},
"type": "object",
"required": ["task_id", "last_message_id", "operation_id", "tool_name"],
"title": "ActiveStreamInfo",
"description": "Information about an active stream for reconnection."
},
"AddUserCreditsResponse": { "AddUserCreditsResponse": {
"properties": { "properties": {
"new_balance": { "type": "integer", "title": "New Balance" }, "new_balance": { "type": "integer", "title": "New Balance" },
@@ -9002,27 +8823,6 @@
],
"title": "OnboardingStep"
},
"OperationCompleteRequest": {
"properties": {
"success": { "type": "boolean", "title": "Success" },
"result": {
"anyOf": [
{ "additionalProperties": true, "type": "object" },
{ "type": "string" },
{ "type": "null" }
],
"title": "Result"
},
"error": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Error"
}
},
"type": "object",
"required": ["success"],
"title": "OperationCompleteRequest",
"description": "Request model for external completion webhook."
},
"Pagination": { "Pagination": {
"properties": { "properties": {
"total_items": { "total_items": {
@@ -9878,12 +9678,6 @@
"items": { "additionalProperties": true, "type": "object" }, "items": { "additionalProperties": true, "type": "object" },
"type": "array", "type": "array",
"title": "Messages" "title": "Messages"
},
"active_stream": {
"anyOf": [
{ "$ref": "#/components/schemas/ActiveStreamInfo" },
{ "type": "null" }
]
}
},
"type": "object",

View File

@@ -1,15 +1,27 @@
"use client"; "use client";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner"; import { getHomepageRoute } from "@/lib/constants";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import { useEffect } from "react"; import { useEffect } from "react";
export default function Page() { export default function Page() {
const isChatEnabled = useGetFlag(Flag.CHAT);
const router = useRouter(); const router = useRouter();
const homepageRoute = getHomepageRoute(isChatEnabled);
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
const isFlagReady =
!isLaunchDarklyConfigured || typeof isChatEnabled === "boolean";
useEffect(() => { useEffect(
router.replace("/copilot"); function redirectToHomepage() {
}, [router]); if (!isFlagReady) return;
router.replace(homepageRoute);
},
[homepageRoute, isFlagReady, router],
);
return <LoadingSpinner size="large" cover />; return null;
} }
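The LaunchDarkly-readiness check above is duplicated almost verbatim in useCopilotPage. One way it could be consolidated, shown here only as an illustrative sketch built from identifiers that appear in this diff:

```typescript
// Illustrative sketch (not in the diff): a shared hook for the CHAT flag guard.
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";

export function useChatFlagReady() {
  const isChatEnabled = useGetFlag(Flag.CHAT);
  const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
  const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
  const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
  // Without LaunchDarkly configured, the flag value is final immediately.
  const isFlagReady =
    !isLaunchDarklyConfigured || typeof isChatEnabled === "boolean";
  return { isChatEnabled, isFlagReady };
}
```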

View File

@@ -1,6 +1,7 @@
"use client"; "use client";
import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId"; import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId";
import { useCopilotStore } from "@/app/(platform)/copilot/copilot-page-store";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
@@ -24,8 +25,8 @@ export function Chat({
}: ChatProps) {
const { urlSessionId } = useCopilotSessionId();
const hasHandledNotFoundRef = useRef(false);
const isSwitchingSession = useCopilotStore((s) => s.isSwitchingSession);
const {
-session,
messages,
isLoading,
isCreating,
@@ -37,18 +38,6 @@ export function Chat({
startPollingForOperation,
} = useChat({ urlSessionId });
// Extract active stream info for reconnection
const activeStream = (
session as {
active_stream?: {
task_id: string;
last_message_id: string;
operation_id: string;
tool_name: string;
};
}
)?.active_stream;
useEffect(() => {
if (!onSessionNotFound) return;
if (!urlSessionId) return;
@@ -64,7 +53,8 @@ export function Chat({
isCreating,
]);
-const shouldShowLoader = showLoader && (isLoading || isCreating);
+const shouldShowLoader =
+(showLoader && (isLoading || isCreating)) || isSwitchingSession;
return (
<div className={cn("flex h-full flex-col", className)}>
@@ -76,19 +66,21 @@ export function Chat({
<div className="flex flex-col items-center gap-3"> <div className="flex flex-col items-center gap-3">
<LoadingSpinner size="large" className="text-neutral-400" /> <LoadingSpinner size="large" className="text-neutral-400" />
<Text variant="body" className="text-zinc-500"> <Text variant="body" className="text-zinc-500">
Loading your chat... {isSwitchingSession
? "Switching chat..."
: "Loading your chat..."}
</Text> </Text>
</div> </div>
</div> </div>
)} )}
{/* Error State */} {/* Error State */}
{error && !isLoading && ( {error && !isLoading && !isSwitchingSession && (
<ChatErrorState error={error} onRetry={createSession} /> <ChatErrorState error={error} onRetry={createSession} />
)} )}
{/* Session Content */} {/* Session Content */}
{sessionId && !isLoading && !error && ( {sessionId && !isLoading && !error && !isSwitchingSession && (
<ChatContainer <ChatContainer
sessionId={sessionId} sessionId={sessionId}
initialMessages={messages} initialMessages={messages}
@@ -96,16 +88,6 @@ export function Chat({
className="flex-1" className="flex-1"
onStreamingChange={onStreamingChange} onStreamingChange={onStreamingChange}
onOperationStarted={startPollingForOperation} onOperationStarted={startPollingForOperation}
activeStream={
activeStream
? {
taskId: activeStream.task_id,
lastMessageId: activeStream.last_message_id,
operationId: activeStream.operation_id,
toolName: activeStream.tool_name,
}
: undefined
}
/>
)}
</main>

View File

@@ -1,159 +0,0 @@
# SSE Reconnection Contract for Long-Running Operations
This document describes the client-side contract for handling SSE (Server-Sent Events) disconnections and reconnecting to long-running background tasks.
## Overview
When a user triggers a long-running operation (like agent generation), the backend:
1. Spawns a background task that survives SSE disconnections
2. Returns an `operation_started` response with a `task_id`
3. Stores stream messages in Redis Streams for replay
Clients can reconnect to the task stream at any time to receive missed messages.
## Client-Side Flow
### 1. Receiving Operation Started
When you receive an `operation_started` tool response:
```typescript
// The response includes a task_id for reconnection
{
type: "operation_started",
tool_name: "generate_agent",
operation_id: "uuid-...",
task_id: "task-uuid-...", // <-- Store this for reconnection
message: "Operation started. You can close this tab."
}
```
### 2. Storing Task Info
Use the chat store to track the active task:
```typescript
import { useChatStore } from "./chat-store";
// When operation_started is received:
useChatStore.getState().setActiveTask(sessionId, {
taskId: response.task_id,
operationId: response.operation_id,
toolName: response.tool_name,
lastMessageId: "0",
});
```
### 3. Reconnecting to a Task
To reconnect (e.g., after page refresh or tab reopen):
```typescript
const { reconnectToTask, getActiveTask } = useChatStore.getState();
// Check if there's an active task for this session
const activeTask = getActiveTask(sessionId);
if (activeTask) {
// Reconnect to the task stream
await reconnectToTask(
sessionId,
activeTask.taskId,
activeTask.lastMessageId, // Resume from last position
(chunk) => {
// Handle incoming chunks
console.log("Received chunk:", chunk);
},
);
}
```
### 4. Tracking Message Position
To enable precise replay, update the last message ID as chunks arrive:
```typescript
const { updateTaskLastMessageId } = useChatStore.getState();
function handleChunk(chunk: StreamChunk) {
// If chunk has an index/id, track it
if (chunk.idx !== undefined) {
updateTaskLastMessageId(sessionId, String(chunk.idx));
}
}
```
## API Endpoints
### Task Stream Reconnection
```
GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
```
- `taskId`: The task ID from `operation_started`
- `last_message_id`: Last received message index (default: "0" for full replay)
Returns: SSE stream of missed messages + live updates
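A minimal sketch of consuming this endpoint directly with `EventSource` (this assumes cookie-based auth through the Next.js proxy and default, unnamed SSE events; header-based auth would need a fetch-based reader instead):

```typescript
function openTaskStream(taskId: string, lastMessageId = "0") {
  const source = new EventSource(
    `/api/chat/tasks/${taskId}/stream?last_message_id=${lastMessageId}`,
  );
  source.onmessage = (event) => {
    const chunk = JSON.parse(event.data);
    if (chunk.type === "finish") source.close(); // stream ends on "finish"
  };
  source.onerror = () => source.close(); // then fall back to polling (see Error Handling)
  return source;
}
```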
## Chunk Types
The reconnected stream follows the same Vercel AI SDK protocol:
| Type | Description |
| ----------------------- | ----------------------- |
| `start` | Message lifecycle start |
| `text-delta` | Streaming text content |
| `text-end` | Text block completed |
| `tool-output-available` | Tool result available |
| `finish` | Stream completed |
| `error` | Error occurred |
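A sketch of dispatching on these chunk types; the chunk shape below is a simplified assumption based on the table, not the exact wire format:

```typescript
type ReplayChunk =
  | { type: "start"; messageId: string }
  | { type: "text-delta"; id: string; delta: string }
  | { type: "text-end"; id: string }
  | { type: "tool-output-available"; toolCallId: string; output: unknown }
  | { type: "finish" }
  | { type: "error"; errorText: string };

function handleReplayChunk(chunk: ReplayChunk, appendText: (t: string) => void) {
  switch (chunk.type) {
    case "text-delta":
      appendText(chunk.delta); // streaming text content
      break;
    case "finish":
      // Stream completed - safe to clear the persisted active task.
      break;
    case "error":
      console.error("Task stream error:", chunk.errorText);
      break;
    default:
      break; // start / text-end / tool-output-available as needed
  }
}
```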
## Error Handling
If reconnection fails:
1. Check if task still exists (may have expired - default TTL: 1 hour)
2. Fall back to polling the session for final state
3. Show appropriate UI message to user
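A minimal sketch of that fallback sequence; `getSession` and `showToast` are placeholders for the app's real session fetch and toast helper:

```typescript
declare function getSession(sessionId: string): Promise<unknown>;
declare function showToast(message: string): void;

async function recoverFromFailedReconnect(sessionId: string) {
  try {
    // The task may simply have expired (default TTL: 1 hour),
    // so fall back to polling the session for its final state.
    return await getSession(sessionId);
  } catch {
    // Surface a user-facing message when even polling fails.
    showToast("This operation is no longer available. Please try again.");
    return null;
  }
}
```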
## Persistence Considerations
For robust reconnection across browser restarts:
```typescript
// Store in localStorage/sessionStorage
const ACTIVE_TASKS_KEY = "chat_active_tasks";
function persistActiveTask(sessionId: string, task: ActiveTaskInfo) {
const tasks = JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
tasks[sessionId] = task;
localStorage.setItem(ACTIVE_TASKS_KEY, JSON.stringify(tasks));
}
function loadPersistedTasks(): Record<string, ActiveTaskInfo> {
return JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
}
```
## Backend Configuration
The following backend settings affect reconnection behavior:
| Setting | Default | Description |
| ------------------- | ------- | ---------------------------------- |
| `stream_ttl` | 3600s | How long streams are kept in Redis |
| `stream_max_length` | 1000 | Max messages per stream |
## Testing
To test reconnection locally:
1. Start a long-running operation (e.g., agent generation)
2. Note the `task_id` from the `operation_started` response
3. Close the browser tab
4. Reopen and call `reconnectToTask` with the saved `task_id`
5. Verify that missed messages are replayed
See the main README for full local development setup.

View File

@@ -1,16 +0,0 @@
/**
* Constants for the chat system.
*
* Centralizes magic strings and values used across chat components.
*/
// LocalStorage keys
export const STORAGE_KEY_ACTIVE_TASKS = "chat_active_tasks";
// Redis Stream IDs
export const INITIAL_MESSAGE_ID = "0";
export const INITIAL_STREAM_ID = "0-0";
// TTL values (in milliseconds)
export const COMPLETED_STREAM_TTL_MS = 5 * 60 * 1000; // 5 minutes
export const ACTIVE_TASK_TTL_MS = 60 * 60 * 1000; // 1 hour

View File

@@ -1,12 +1,6 @@
"use client"; "use client";
import { create } from "zustand"; import { create } from "zustand";
import {
ACTIVE_TASK_TTL_MS,
COMPLETED_STREAM_TTL_MS,
INITIAL_STREAM_ID,
STORAGE_KEY_ACTIVE_TASKS,
} from "./chat-constants";
import type {
ActiveStream,
StreamChunk,
@@ -14,59 +8,15 @@ import type {
StreamResult,
StreamStatus,
} from "./chat-types";
-import { executeStream, executeTaskReconnect } from "./stream-executor";
+import { executeStream } from "./stream-executor";
+const COMPLETED_STREAM_TTL = 5 * 60 * 1000; // 5 minutes
-export interface ActiveTaskInfo {
taskId: string;
sessionId: string;
operationId: string;
toolName: string;
lastMessageId: string;
startedAt: number;
}
/** Load active tasks from localStorage */
function loadPersistedTasks(): Map<string, ActiveTaskInfo> {
if (typeof window === "undefined") return new Map();
try {
const stored = localStorage.getItem(STORAGE_KEY_ACTIVE_TASKS);
if (!stored) return new Map();
const parsed = JSON.parse(stored) as Record<string, ActiveTaskInfo>;
const now = Date.now();
const tasks = new Map<string, ActiveTaskInfo>();
// Filter out expired tasks
for (const [sessionId, task] of Object.entries(parsed)) {
if (now - task.startedAt < ACTIVE_TASK_TTL_MS) {
tasks.set(sessionId, task);
}
}
return tasks;
} catch {
return new Map();
}
}
/** Save active tasks to localStorage */
function persistTasks(tasks: Map<string, ActiveTaskInfo>): void {
if (typeof window === "undefined") return;
try {
const obj: Record<string, ActiveTaskInfo> = {};
for (const [sessionId, task] of tasks) {
obj[sessionId] = task;
}
localStorage.setItem(STORAGE_KEY_ACTIVE_TASKS, JSON.stringify(obj));
} catch {
// Ignore storage errors
}
}
interface ChatStoreState {
activeStreams: Map<string, ActiveStream>;
completedStreams: Map<string, StreamResult>;
activeSessions: Set<string>;
streamCompleteCallbacks: Set<StreamCompleteCallback>;
/** Active tasks for SSE reconnection - keyed by sessionId */
activeTasks: Map<string, ActiveTaskInfo>;
}
interface ChatStoreActions {
@@ -91,24 +41,6 @@ interface ChatStoreActions {
unregisterActiveSession: (sessionId: string) => void;
isSessionActive: (sessionId: string) => boolean;
onStreamComplete: (callback: StreamCompleteCallback) => () => void;
/** Track active task for SSE reconnection */
setActiveTask: (
sessionId: string,
taskInfo: Omit<ActiveTaskInfo, "sessionId" | "startedAt">,
) => void;
/** Get active task for a session */
getActiveTask: (sessionId: string) => ActiveTaskInfo | undefined;
/** Clear active task when operation completes */
clearActiveTask: (sessionId: string) => void;
/** Reconnect to an existing task stream */
reconnectToTask: (
sessionId: string,
taskId: string,
lastMessageId?: string,
onChunk?: (chunk: StreamChunk) => void,
) => Promise<void>;
/** Update last message ID for a task (for tracking replay position) */
updateTaskLastMessageId: (sessionId: string, lastMessageId: string) => void;
}
type ChatStore = ChatStoreState & ChatStoreActions;
@@ -132,126 +64,18 @@ function cleanupExpiredStreams(
const now = Date.now();
const cleaned = new Map(completedStreams);
for (const [sessionId, result] of cleaned) {
-if (now - result.completedAt > COMPLETED_STREAM_TTL_MS) {
+if (now - result.completedAt > COMPLETED_STREAM_TTL) {
cleaned.delete(sessionId);
}
}
return cleaned;
}
/**
* Finalize a stream by moving it from activeStreams to completedStreams.
* Also handles cleanup and notifications.
*/
function finalizeStream(
sessionId: string,
stream: ActiveStream,
onChunk: ((chunk: StreamChunk) => void) | undefined,
get: () => ChatStoreState & ChatStoreActions,
set: (state: Partial<ChatStoreState>) => void,
): void {
if (onChunk) stream.onChunkCallbacks.delete(onChunk);
if (stream.status !== "streaming") {
const currentState = get();
const finalActiveStreams = new Map(currentState.activeStreams);
let finalCompletedStreams = new Map(currentState.completedStreams);
const storedStream = finalActiveStreams.get(sessionId);
if (storedStream === stream) {
const result: StreamResult = {
sessionId,
status: stream.status,
chunks: stream.chunks,
completedAt: Date.now(),
error: stream.error,
};
finalCompletedStreams.set(sessionId, result);
finalActiveStreams.delete(sessionId);
finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
set({
activeStreams: finalActiveStreams,
completedStreams: finalCompletedStreams,
});
if (stream.status === "completed" || stream.status === "error") {
notifyStreamComplete(currentState.streamCompleteCallbacks, sessionId);
}
}
}
}
/**
* Clean up an existing stream for a session and move it to completed streams.
* Returns updated maps for both active and completed streams.
*/
function cleanupExistingStream(
sessionId: string,
activeStreams: Map<string, ActiveStream>,
completedStreams: Map<string, StreamResult>,
callbacks: Set<StreamCompleteCallback>,
): {
activeStreams: Map<string, ActiveStream>;
completedStreams: Map<string, StreamResult>;
} {
const newActiveStreams = new Map(activeStreams);
let newCompletedStreams = new Map(completedStreams);
const existingStream = newActiveStreams.get(sessionId);
if (existingStream) {
existingStream.abortController.abort();
const normalizedStatus =
existingStream.status === "streaming"
? "completed"
: existingStream.status;
const result: StreamResult = {
sessionId,
status: normalizedStatus,
chunks: existingStream.chunks,
completedAt: Date.now(),
error: existingStream.error,
};
newCompletedStreams.set(sessionId, result);
newActiveStreams.delete(sessionId);
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
if (normalizedStatus === "completed" || normalizedStatus === "error") {
notifyStreamComplete(callbacks, sessionId);
}
}
return {
activeStreams: newActiveStreams,
completedStreams: newCompletedStreams,
};
}
/**
* Create a new active stream with initial state.
*/
function createActiveStream(
sessionId: string,
onChunk?: (chunk: StreamChunk) => void,
): ActiveStream {
const abortController = new AbortController();
const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
if (onChunk) initialCallbacks.add(onChunk);
return {
sessionId,
abortController,
status: "streaming",
startedAt: Date.now(),
chunks: [],
onChunkCallbacks: initialCallbacks,
};
}
export const useChatStore = create<ChatStore>((set, get) => ({
activeStreams: new Map(),
completedStreams: new Map(),
activeSessions: new Set(),
streamCompleteCallbacks: new Set(),
activeTasks: loadPersistedTasks(),
startStream: async function startStream(
sessionId,
@@ -261,21 +85,45 @@ export const useChatStore = create<ChatStore>((set, get) => ({
onChunk,
) {
const state = get();
const newActiveStreams = new Map(state.activeStreams);
let newCompletedStreams = new Map(state.completedStreams);
const callbacks = state.streamCompleteCallbacks;
-// Clean up any existing stream for this session
-const {
-activeStreams: newActiveStreams,
-completedStreams: newCompletedStreams,
-} = cleanupExistingStream(
-sessionId,
-state.activeStreams,
-state.completedStreams,
-callbacks,
-);
-// Create new stream
-const stream = createActiveStream(sessionId, onChunk);
+const existingStream = newActiveStreams.get(sessionId);
+if (existingStream) {
+existingStream.abortController.abort();
+const normalizedStatus =
+existingStream.status === "streaming"
+? "completed"
+: existingStream.status;
+const result: StreamResult = {
+sessionId,
+status: normalizedStatus,
+chunks: existingStream.chunks,
+completedAt: Date.now(),
+error: existingStream.error,
+};
+newCompletedStreams.set(sessionId, result);
+newActiveStreams.delete(sessionId);
+newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
+if (normalizedStatus === "completed" || normalizedStatus === "error") {
+notifyStreamComplete(callbacks, sessionId);
+}
+}
+const abortController = new AbortController();
+const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
+if (onChunk) initialCallbacks.add(onChunk);
+const stream: ActiveStream = {
+sessionId,
+abortController,
+status: "streaming",
+startedAt: Date.now(),
+chunks: [],
+onChunkCallbacks: initialCallbacks,
+};
newActiveStreams.set(sessionId, stream);
set({
activeStreams: newActiveStreams,
@@ -285,7 +133,36 @@ export const useChatStore = create<ChatStore>((set, get) => ({
try {
await executeStream(stream, message, isUserMessage, context);
} finally {
-finalizeStream(sessionId, stream, onChunk, get, set);
+if (onChunk) stream.onChunkCallbacks.delete(onChunk);
if (stream.status !== "streaming") {
const currentState = get();
const finalActiveStreams = new Map(currentState.activeStreams);
let finalCompletedStreams = new Map(currentState.completedStreams);
const storedStream = finalActiveStreams.get(sessionId);
if (storedStream === stream) {
const result: StreamResult = {
sessionId,
status: stream.status,
chunks: stream.chunks,
completedAt: Date.now(),
error: stream.error,
};
finalCompletedStreams.set(sessionId, result);
finalActiveStreams.delete(sessionId);
finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
set({
activeStreams: finalActiveStreams,
completedStreams: finalCompletedStreams,
});
if (stream.status === "completed" || stream.status === "error") {
notifyStreamComplete(
currentState.streamCompleteCallbacks,
sessionId,
);
}
}
}
}
},
@@ -409,93 +286,4 @@ export const useChatStore = create<ChatStore>((set, get) => ({
set({ streamCompleteCallbacks: cleanedCallbacks });
};
},
setActiveTask: function setActiveTask(sessionId, taskInfo) {
const state = get();
const newActiveTasks = new Map(state.activeTasks);
newActiveTasks.set(sessionId, {
...taskInfo,
sessionId,
startedAt: Date.now(),
});
set({ activeTasks: newActiveTasks });
persistTasks(newActiveTasks);
},
getActiveTask: function getActiveTask(sessionId) {
return get().activeTasks.get(sessionId);
},
clearActiveTask: function clearActiveTask(sessionId) {
const state = get();
if (!state.activeTasks.has(sessionId)) return;
const newActiveTasks = new Map(state.activeTasks);
newActiveTasks.delete(sessionId);
set({ activeTasks: newActiveTasks });
persistTasks(newActiveTasks);
},
reconnectToTask: async function reconnectToTask(
sessionId,
taskId,
lastMessageId = INITIAL_STREAM_ID,
onChunk,
) {
const state = get();
const callbacks = state.streamCompleteCallbacks;
// Clean up any existing stream for this session
const {
activeStreams: newActiveStreams,
completedStreams: newCompletedStreams,
} = cleanupExistingStream(
sessionId,
state.activeStreams,
state.completedStreams,
callbacks,
);
// Create new stream for reconnection
const stream = createActiveStream(sessionId, onChunk);
newActiveStreams.set(sessionId, stream);
set({
activeStreams: newActiveStreams,
completedStreams: newCompletedStreams,
});
try {
await executeTaskReconnect(stream, taskId, lastMessageId);
} finally {
finalizeStream(sessionId, stream, onChunk, get, set);
// Clear active task on completion
if (stream.status === "completed" || stream.status === "error") {
const taskState = get();
if (taskState.activeTasks.has(sessionId)) {
const newActiveTasks = new Map(taskState.activeTasks);
newActiveTasks.delete(sessionId);
set({ activeTasks: newActiveTasks });
persistTasks(newActiveTasks);
}
}
}
},
updateTaskLastMessageId: function updateTaskLastMessageId(
sessionId,
lastMessageId,
) {
const state = get();
const task = state.activeTasks.get(sessionId);
if (!task) return;
const newActiveTasks = new Map(state.activeTasks);
newActiveTasks.set(sessionId, {
...task,
lastMessageId,
});
set({ activeTasks: newActiveTasks });
persistTasks(newActiveTasks);
},
}));

View File

@@ -4,7 +4,6 @@ export type StreamStatus = "idle" | "streaming" | "completed" | "error";
export interface StreamChunk {
type:
| "stream_start"
| "text_chunk" | "text_chunk"
| "text_ended" | "text_ended"
| "tool_call" | "tool_call"
@@ -16,7 +15,6 @@ export interface StreamChunk {
| "error" | "error"
| "usage" | "usage"
| "stream_end"; | "stream_end";
taskId?: string;
timestamp?: string;
content?: string;
message?: string;
@@ -43,7 +41,7 @@ export interface StreamChunk {
}
export type VercelStreamChunk =
-| { type: "start"; messageId: string; taskId?: string }
+| { type: "start"; messageId: string }
| { type: "finish" }
| { type: "text-start"; id: string }
| { type: "text-delta"; id: string; delta: string }
@@ -94,70 +92,3 @@ export interface StreamResult {
}
export type StreamCompleteCallback = (sessionId: string) => void;
// Type guards for message types
/**
* Check if a message has a toolId property.
*/
export function hasToolId<T extends { type: string }>(
msg: T,
): msg is T & { toolId: string } {
return (
"toolId" in msg &&
typeof (msg as Record<string, unknown>).toolId === "string"
);
}
/**
* Check if a message has an operationId property.
*/
export function hasOperationId<T extends { type: string }>(
msg: T,
): msg is T & { operationId: string } {
return (
"operationId" in msg &&
typeof (msg as Record<string, unknown>).operationId === "string"
);
}
/**
* Check if a message has a toolCallId property.
*/
export function hasToolCallId<T extends { type: string }>(
msg: T,
): msg is T & { toolCallId: string } {
return (
"toolCallId" in msg &&
typeof (msg as Record<string, unknown>).toolCallId === "string"
);
}
/**
* Check if a message is an operation message type.
*/
export function isOperationMessage<T extends { type: string }>(
msg: T,
): msg is T & {
type: "operation_started" | "operation_pending" | "operation_in_progress";
} {
return (
msg.type === "operation_started" ||
msg.type === "operation_pending" ||
msg.type === "operation_in_progress"
);
}
/**
* Get the tool ID from a message if available.
* Checks toolId, operationId, and toolCallId properties.
*/
export function getToolIdFromMessage<T extends { type: string }>(
msg: T,
): string | undefined {
const record = msg as Record<string, unknown>;
if (typeof record.toolId === "string") return record.toolId;
if (typeof record.operationId === "string") return record.operationId;
if (typeof record.toolCallId === "string") return record.toolCallId;
return undefined;
}

View File

@@ -17,13 +17,6 @@ export interface ChatContainerProps {
className?: string;
onStreamingChange?: (isStreaming: boolean) => void;
onOperationStarted?: () => void;
/** Active stream info from the server for reconnection */
activeStream?: {
taskId: string;
lastMessageId: string;
operationId: string;
toolName: string;
};
}
export function ChatContainer({
@@ -33,7 +26,6 @@ export function ChatContainer({
className,
onStreamingChange,
onOperationStarted,
-activeStream,
}: ChatContainerProps) {
const {
messages,
@@ -49,7 +41,6 @@ export function ChatContainer({
initialMessages,
initialPrompt,
onOperationStarted,
-activeStream,
});
useEffect(() => {

View File

@@ -2,7 +2,6 @@ import { toast } from "sonner";
import type { StreamChunk } from "../../chat-types";
import type { HandlerDependencies } from "./handlers";
import {
-getErrorDisplayMessage,
handleError,
handleLoginNeeded,
handleStreamEnd,
@@ -25,22 +24,16 @@ export function createStreamEventDispatcher(
chunk.type === "need_login" || chunk.type === "need_login" ||
chunk.type === "error" chunk.type === "error"
) { ) {
if (!deps.hasResponseRef.current) {
console.info("[ChatStream] First response chunk:", {
type: chunk.type,
sessionId: deps.sessionId,
});
}
deps.hasResponseRef.current = true;
}
switch (chunk.type) {
case "stream_start":
// Store task ID for SSE reconnection
if (chunk.taskId && deps.onActiveTaskStarted) {
deps.onActiveTaskStarted({
taskId: chunk.taskId,
operationId: chunk.taskId,
toolName: "chat",
toolCallId: "chat_stream",
});
}
break;
case "text_chunk": case "text_chunk":
handleTextChunk(chunk, deps); handleTextChunk(chunk, deps);
break; break;
@@ -63,7 +56,11 @@ export function createStreamEventDispatcher(
break;
case "stream_end":
-// Note: "finish" type from backend gets normalized to "stream_end" by normalizeStreamChunk
+console.info("[ChatStream] Stream ended:", {
+sessionId: deps.sessionId,
+hasResponse: deps.hasResponseRef.current,
+chunkCount: deps.streamingChunksRef.current.length,
+});
handleStreamEnd(chunk, deps);
break;
@@ -73,7 +70,7 @@ export function createStreamEventDispatcher(
// Show toast at dispatcher level to avoid circular dependencies
if (!isRegionBlocked) {
toast.error("Chat Error", {
-description: getErrorDisplayMessage(chunk),
+description: chunk.message || chunk.content || "An error occurred",
});
}
break;

View File

@@ -18,19 +18,11 @@ export interface HandlerDependencies {
setStreamingChunks: Dispatch<SetStateAction<string[]>>;
streamingChunksRef: MutableRefObject<string[]>;
hasResponseRef: MutableRefObject<boolean>;
textFinalizedRef: MutableRefObject<boolean>;
streamEndedRef: MutableRefObject<boolean>;
setMessages: Dispatch<SetStateAction<ChatMessageData[]>>;
setIsStreamingInitiated: Dispatch<SetStateAction<boolean>>;
setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
sessionId: string;
onOperationStarted?: () => void;
onActiveTaskStarted?: (taskInfo: {
taskId: string;
operationId: string;
toolName: string;
toolCallId: string;
}) => void;
}
export function isRegionBlockedError(chunk: StreamChunk): boolean {
@@ -40,25 +32,6 @@ export function isRegionBlockedError(chunk: StreamChunk): boolean {
return message.toLowerCase().includes("not available in your region");
}
export function getUserFriendlyErrorMessage(
code: string | undefined,
): string | undefined {
switch (code) {
case "TASK_EXPIRED":
return "This operation has expired. Please try again.";
case "TASK_NOT_FOUND":
return "Could not find the requested operation.";
case "ACCESS_DENIED":
return "You do not have access to this operation.";
case "QUEUE_OVERFLOW":
return "Connection was interrupted. Please refresh to continue.";
case "MODEL_NOT_AVAILABLE_REGION":
return "This model is not available in your region.";
default:
return undefined;
}
}
export function handleTextChunk(chunk: StreamChunk, deps: HandlerDependencies) {
if (!chunk.content) return;
deps.setHasTextChunks(true);
@@ -73,15 +46,10 @@ export function handleTextEnded(
_chunk: StreamChunk,
deps: HandlerDependencies,
) {
if (deps.textFinalizedRef.current) {
return;
}
const completedText = deps.streamingChunksRef.current.join("");
if (completedText.trim()) {
deps.textFinalizedRef.current = true;
deps.setMessages((prev) => {
// Check if this exact message already exists to prevent duplicates
const exists = prev.some(
(msg) =>
msg.type === "message" &&
@@ -108,14 +76,9 @@ export function handleToolCallStart(
chunk: StreamChunk,
deps: HandlerDependencies,
) {
// Use deterministic fallback instead of Date.now() to ensure same ID on replay
const toolId =
chunk.tool_id ||
`tool-${deps.sessionId}-${chunk.idx ?? "unknown"}-${chunk.tool_name || "unknown"}`;
const toolCallMessage: Extract<ChatMessageData, { type: "tool_call" }> = {
type: "tool_call",
-toolId,
+toolId: chunk.tool_id || `tool-${Date.now()}-${chunk.idx || 0}`,
toolName: chunk.tool_name || "Executing",
arguments: chunk.arguments || {},
timestamp: new Date(),
@@ -148,29 +111,6 @@ export function handleToolCallStart(
   deps.setMessages(updateToolCallMessages);
 }

-const TOOL_RESPONSE_TYPES = new Set([
-  "tool_response",
-  "operation_started",
-  "operation_pending",
-  "operation_in_progress",
-  "execution_started",
-  "agent_carousel",
-  "clarification_needed",
-]);
-
-function hasResponseForTool(
-  messages: ChatMessageData[],
-  toolId: string,
-): boolean {
-  return messages.some((msg) => {
-    if (!TOOL_RESPONSE_TYPES.has(msg.type)) return false;
-    const msgToolId =
-      (msg as { toolId?: string }).toolId ||
-      (msg as { toolCallId?: string }).toolCallId;
-    return msgToolId === toolId;
-  });
-}
-
 export function handleToolResponse(
   chunk: StreamChunk,
   deps: HandlerDependencies,
@@ -212,49 +152,31 @@ export function handleToolResponse(
   ) {
       const inputsMessage = extractInputsNeeded(parsedResult, chunk.tool_name);
       if (inputsMessage) {
-        deps.setMessages((prev) => {
-          // Check for duplicate inputs_needed message
-          const exists = prev.some((msg) => msg.type === "inputs_needed");
-          if (exists) return prev;
-          return [...prev, inputsMessage];
-        });
+        deps.setMessages((prev) => [...prev, inputsMessage]);
       }
       const credentialsMessage = extractCredentialsNeeded(
         parsedResult,
         chunk.tool_name,
       );
       if (credentialsMessage) {
-        deps.setMessages((prev) => {
-          // Check for duplicate credentials_needed message
-          const exists = prev.some((msg) => msg.type === "credentials_needed");
-          if (exists) return prev;
-          return [...prev, credentialsMessage];
-        });
+        deps.setMessages((prev) => [...prev, credentialsMessage]);
       }
     }
     return;
   }

-  // Trigger polling when operation_started is received
   if (responseMessage.type === "operation_started") {
     deps.onOperationStarted?.();
-    const taskId = (responseMessage as { taskId?: string }).taskId;
-    if (taskId && deps.onActiveTaskStarted) {
-      deps.onActiveTaskStarted({
-        taskId,
-        operationId:
-          (responseMessage as { operationId?: string }).operationId || "",
-        toolName: (responseMessage as { toolName?: string }).toolName || "",
-        toolCallId: (responseMessage as { toolId?: string }).toolId || "",
-      });
-    }
   }

   deps.setMessages((prev) => {
     const toolCallIndex = prev.findIndex(
       (msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id,
     );
-    if (hasResponseForTool(prev, chunk.tool_id!)) {
-      return prev;
-    }
+    const hasResponse = prev.some(
+      (msg) => msg.type === "tool_response" && msg.toolId === chunk.tool_id,
+    );
+    if (hasResponse) return prev;
     if (toolCallIndex !== -1) {
       const newMessages = [...prev];
       newMessages.splice(toolCallIndex + 1, 0, responseMessage);
@@ -276,48 +198,28 @@ export function handleLoginNeeded(
     agentInfo: chunk.agent_info,
     timestamp: new Date(),
   };
-  deps.setMessages((prev) => {
-    // Check for duplicate login_needed message
-    const exists = prev.some((msg) => msg.type === "login_needed");
-    if (exists) return prev;
-    return [...prev, loginNeededMessage];
-  });
+  deps.setMessages((prev) => [...prev, loginNeededMessage]);
 }

 export function handleStreamEnd(
   _chunk: StreamChunk,
   deps: HandlerDependencies,
 ) {
-  if (deps.streamEndedRef.current) {
-    return;
-  }
-  deps.streamEndedRef.current = true;
   const completedContent = deps.streamingChunksRef.current.join("");
   if (!completedContent.trim() && !deps.hasResponseRef.current) {
-    deps.setMessages((prev) => {
-      const exists = prev.some(
-        (msg) =>
-          msg.type === "message" &&
-          msg.role === "assistant" &&
-          msg.content === "No response received. Please try again.",
-      );
-      if (exists) return prev;
-      return [
-        ...prev,
-        {
-          type: "message",
-          role: "assistant",
-          content: "No response received. Please try again.",
-          timestamp: new Date(),
-        },
-      ];
-    });
+    deps.setMessages((prev) => [
+      ...prev,
+      {
+        type: "message",
+        role: "assistant",
+        content: "No response received. Please try again.",
+        timestamp: new Date(),
+      },
+    ]);
   }
-  if (completedContent.trim() && !deps.textFinalizedRef.current) {
-    deps.textFinalizedRef.current = true;
+  if (completedContent.trim()) {
     deps.setMessages((prev) => {
-      // Check if this exact message already exists to prevent duplicates
       const exists = prev.some(
         (msg) =>
           msg.type === "message" &&
@@ -342,6 +244,8 @@ export function handleStreamEnd(
 }

 export function handleError(chunk: StreamChunk, deps: HandlerDependencies) {
+  const errorMessage = chunk.message || chunk.content || "An error occurred";
+  console.error("Stream error:", errorMessage);
   if (isRegionBlockedError(chunk)) {
     deps.setIsRegionBlockedModalOpen(true);
   }
@@ -349,14 +253,4 @@ export function handleError(chunk: StreamChunk, deps: HandlerDependencies) {
   deps.setHasTextChunks(false);
   deps.setStreamingChunks([]);
   deps.streamingChunksRef.current = [];
-  deps.textFinalizedRef.current = false;
-  deps.streamEndedRef.current = true;
-}
-
-export function getErrorDisplayMessage(chunk: StreamChunk): string {
-  const friendlyMessage = getUserFriendlyErrorMessage(chunk.code);
-  if (friendlyMessage) {
-    return friendlyMessage;
-  }
-  return chunk.message || chunk.content || "An error occurred";
 }
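
For context on the pattern removed throughout this file: the old handlers wrapped every append in a duplicate check so that replayed SSE chunks could not produce repeated messages. A minimal sketch of that guard, assuming React state setters and a simplified message shape (the appendOnce helper is hypothetical, not part of the codebase):

import type { Dispatch, SetStateAction } from "react";

type Msg = { type: string; content?: string };

// Append only when no existing message matches; replayed chunks become no-ops.
function appendOnce(
  setMessages: Dispatch<SetStateAction<Msg[]>>,
  message: Msg,
  isDuplicate: (existing: Msg) => boolean,
) {
  setMessages((prev) => (prev.some(isDuplicate) ? prev : [...prev, message]));
}

With stream replay gone, the simpler unconditional appends above are sufficient, at the cost of relying on the stream never delivering the same chunk twice.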

View File

@@ -349,7 +349,6 @@ export function parseToolResponse(
     toolName: (parsedResult.tool_name as string) || toolName,
     toolId,
     operationId: (parsedResult.operation_id as string) || "",
-    taskId: (parsedResult.task_id as string) || undefined, // For SSE reconnection
     message:
       (parsedResult.message as string) ||
       "Operation started. You can close this tab.",
View File

@@ -1,17 +1,10 @@
 import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
 import { useEffect, useMemo, useRef, useState } from "react";
-import { INITIAL_STREAM_ID } from "../../chat-constants";
 import { useChatStore } from "../../chat-store";
 import { toast } from "sonner";
 import { useChatStream } from "../../useChatStream";
 import { usePageContext } from "../../usePageContext";
 import type { ChatMessageData } from "../ChatMessage/useChatMessage";
-import {
-  getToolIdFromMessage,
-  hasToolId,
-  isOperationMessage,
-  type StreamChunk,
-} from "../../chat-types";
 import { createStreamEventDispatcher } from "./createStreamEventDispatcher";
 import {
   createUserMessage,
@@ -21,13 +14,6 @@ import {
   processInitialMessages,
 } from "./helpers";

-const TOOL_RESULT_TYPES = new Set([
-  "tool_response",
-  "agent_carousel",
-  "execution_started",
-  "clarification_needed",
-]);
-
 // Helper to generate deduplication key for a message
 function getMessageKey(msg: ChatMessageData): string {
   if (msg.type === "message") {
@@ -37,18 +23,14 @@ function getMessageKey(msg: ChatMessageData): string {
     return `msg:${msg.role}:${msg.content}`;
   } else if (msg.type === "tool_call") {
     return `toolcall:${msg.toolId}`;
-  } else if (TOOL_RESULT_TYPES.has(msg.type)) {
-    // Unified key for all tool result types - same toolId with different types
-    // (tool_response vs agent_carousel) should deduplicate to the same key
-    const toolId = getToolIdFromMessage(msg);
-    // If no toolId, fall back to content-based key to avoid empty key collisions
-    if (!toolId) {
-      return `toolresult:content:${JSON.stringify(msg).slice(0, 200)}`;
-    }
-    return `toolresult:${toolId}`;
-  } else if (isOperationMessage(msg)) {
-    const toolId = getToolIdFromMessage(msg) || "";
-    return `op:${toolId}:${msg.toolName}`;
+  } else if (msg.type === "tool_response") {
+    return `toolresponse:${(msg as any).toolId}`;
+  } else if (
+    msg.type === "operation_started" ||
+    msg.type === "operation_pending" ||
+    msg.type === "operation_in_progress"
+  ) {
+    return `op:${(msg as any).toolId || (msg as any).operationId || (msg as any).toolCallId || ""}:${msg.toolName}`;
   } else {
     return `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`;
   }
@@ -59,13 +41,6 @@ interface Args {
   initialMessages: SessionDetailResponse["messages"];
   initialPrompt?: string;
   onOperationStarted?: () => void;
-  /** Active stream info from the server for reconnection */
-  activeStream?: {
-    taskId: string;
-    lastMessageId: string;
-    operationId: string;
-    toolName: string;
-  };
 }

 export function useChatContainer({
@@ -73,7 +48,6 @@
   initialMessages,
   initialPrompt,
   onOperationStarted,
-  activeStream,
 }: Args) {
   const [messages, setMessages] = useState<ChatMessageData[]>([]);
   const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
@@ -83,8 +57,6 @@
     useState(false);
   const hasResponseRef = useRef(false);
   const streamingChunksRef = useRef<string[]>([]);
-  const textFinalizedRef = useRef(false);
-  const streamEndedRef = useRef(false);
   const previousSessionIdRef = useRef<string | null>(null);
   const {
     error,
@@ -93,182 +65,44 @@
   } = useChatStream();
   const activeStreams = useChatStore((s) => s.activeStreams);
   const subscribeToStream = useChatStore((s) => s.subscribeToStream);
-  const setActiveTask = useChatStore((s) => s.setActiveTask);
-  const getActiveTask = useChatStore((s) => s.getActiveTask);
-  const reconnectToTask = useChatStore((s) => s.reconnectToTask);
   const isStreaming = isStreamingInitiated || hasTextChunks;

-  // Track whether we've already connected to this activeStream to avoid duplicate connections
-  const connectedActiveStreamRef = useRef<string | null>(null);
-  // Track if component is mounted to prevent state updates after unmount
-  const isMountedRef = useRef(true);
-  // Track current dispatcher to prevent multiple dispatchers from adding messages
-  const currentDispatcherIdRef = useRef(0);
-
-  // Set mounted flag - reset on every mount, cleanup on unmount
-  useEffect(function trackMountedState() {
-    isMountedRef.current = true;
-    return function cleanup() {
-      isMountedRef.current = false;
-    };
-  }, []);
-
-  // Callback to store active task info for SSE reconnection
-  function handleActiveTaskStarted(taskInfo: {
-    taskId: string;
-    operationId: string;
-    toolName: string;
-    toolCallId: string;
-  }) {
-    if (!sessionId) return;
-    setActiveTask(sessionId, {
-      taskId: taskInfo.taskId,
-      operationId: taskInfo.operationId,
-      toolName: taskInfo.toolName,
-      lastMessageId: INITIAL_STREAM_ID,
-    });
-  }
-
-  // Create dispatcher for stream events - stable reference for current sessionId
-  // Each dispatcher gets a unique ID to prevent stale dispatchers from updating state
-  function createDispatcher() {
-    if (!sessionId) return () => {};
-    // Increment dispatcher ID - only the most recent dispatcher should update state
-    const dispatcherId = ++currentDispatcherIdRef.current;
-    const baseDispatcher = createStreamEventDispatcher({
-      setHasTextChunks,
-      setStreamingChunks,
-      streamingChunksRef,
-      hasResponseRef,
-      textFinalizedRef,
-      streamEndedRef,
-      setMessages,
-      setIsRegionBlockedModalOpen,
-      sessionId,
-      setIsStreamingInitiated,
-      onOperationStarted,
-      onActiveTaskStarted: handleActiveTaskStarted,
-    });
-    // Wrap dispatcher to check if it's still the current one
-    return function guardedDispatcher(chunk: StreamChunk) {
-      // Skip if component unmounted or this is a stale dispatcher
-      if (!isMountedRef.current) {
-        return;
-      }
-      if (dispatcherId !== currentDispatcherIdRef.current) {
-        return;
-      }
-      baseDispatcher(chunk);
-    };
-  }
-
   useEffect(
     function handleSessionChange() {
-      const isSessionChange = sessionId !== previousSessionIdRef.current;
-      // Handle session change - reset state
-      if (isSessionChange) {
-        const prevSession = previousSessionIdRef.current;
-        if (prevSession) {
-          stopStreaming(prevSession);
-        }
-        previousSessionIdRef.current = sessionId;
-        connectedActiveStreamRef.current = null;
-        setMessages([]);
-        setStreamingChunks([]);
-        streamingChunksRef.current = [];
-        setHasTextChunks(false);
-        setIsStreamingInitiated(false);
-        hasResponseRef.current = false;
-        textFinalizedRef.current = false;
-        streamEndedRef.current = false;
+      if (sessionId === previousSessionIdRef.current) return;
+      const prevSession = previousSessionIdRef.current;
+      if (prevSession) {
+        stopStreaming(prevSession);
       }
+      previousSessionIdRef.current = sessionId;
+      setMessages([]);
+      setStreamingChunks([]);
+      streamingChunksRef.current = [];
+      setHasTextChunks(false);
+      setIsStreamingInitiated(false);
+      hasResponseRef.current = false;

       if (!sessionId) return;

-      // Priority 1: Check if server told us there's an active stream (most authoritative)
-      if (activeStream) {
-        const streamKey = `${sessionId}:${activeStream.taskId}`;
-        if (connectedActiveStreamRef.current === streamKey) {
-          return;
-        }
-
-        // Skip if there's already an active stream for this session in the store
-        const existingStream = activeStreams.get(sessionId);
-        if (existingStream && existingStream.status === "streaming") {
-          connectedActiveStreamRef.current = streamKey;
-          return;
-        }
-
-        connectedActiveStreamRef.current = streamKey;
-
-        // Clear all state before reconnection to prevent duplicates
-        // Server's initialMessages is authoritative; local state will be rebuilt from SSE replay
-        setMessages([]);
-        setStreamingChunks([]);
-        streamingChunksRef.current = [];
-        setHasTextChunks(false);
-        textFinalizedRef.current = false;
-        streamEndedRef.current = false;
-        hasResponseRef.current = false;
-        setIsStreamingInitiated(true);
-        setActiveTask(sessionId, {
-          taskId: activeStream.taskId,
-          operationId: activeStream.operationId,
-          toolName: activeStream.toolName,
-          lastMessageId: activeStream.lastMessageId,
-        });
-        reconnectToTask(
-          sessionId,
-          activeStream.taskId,
-          activeStream.lastMessageId,
-          createDispatcher(),
-        );
-        // Don't return cleanup here - the guarded dispatcher handles stale events
-        // and the stream will complete naturally. Cleanup would prematurely stop
-        // the stream when effect re-runs due to activeStreams changing.
-        return;
-      }
-
-      // Only check localStorage/in-memory on session change
-      if (!isSessionChange) return;
-
-      // Priority 2: Check localStorage for active task
-      const activeTask = getActiveTask(sessionId);
-      if (activeTask) {
-        // Clear all state before reconnection to prevent duplicates
-        // Server's initialMessages is authoritative; local state will be rebuilt from SSE replay
-        setMessages([]);
-        setStreamingChunks([]);
-        streamingChunksRef.current = [];
-        setHasTextChunks(false);
-        textFinalizedRef.current = false;
-        streamEndedRef.current = false;
-        hasResponseRef.current = false;
-        setIsStreamingInitiated(true);
-        reconnectToTask(
-          sessionId,
-          activeTask.taskId,
-          activeTask.lastMessageId,
-          createDispatcher(),
-        );
-        // Don't return cleanup here - the guarded dispatcher handles stale events
-        return;
-      }
-
-      // Priority 3: Check for an in-memory active stream (same-tab scenario)
-      const inMemoryStream = activeStreams.get(sessionId);
-      if (!inMemoryStream || inMemoryStream.status !== "streaming") {
-        return;
-      }
+      const activeStream = activeStreams.get(sessionId);
+      if (!activeStream || activeStream.status !== "streaming") return;
+
+      const dispatcher = createStreamEventDispatcher({
+        setHasTextChunks,
+        setStreamingChunks,
+        streamingChunksRef,
+        hasResponseRef,
+        setMessages,
+        setIsRegionBlockedModalOpen,
+        sessionId,
+        setIsStreamingInitiated,
+        onOperationStarted,
+      });

       setIsStreamingInitiated(true);
       const skipReplay = initialMessages.length > 0;
-      return subscribeToStream(sessionId, createDispatcher(), skipReplay);
+      return subscribeToStream(sessionId, dispatcher, skipReplay);
     },
     [
       sessionId,
@@ -276,10 +110,6 @@
       activeStreams,
       subscribeToStream,
       onOperationStarted,
-      getActiveTask,
-      reconnectToTask,
-      activeStream,
-      setActiveTask,
     ],
   );
@@ -294,7 +124,7 @@
         msg.type === "agent_carousel" ||
         msg.type === "execution_started"
       ) {
-        const toolId = hasToolId(msg) ? msg.toolId : undefined;
+        const toolId = (msg as any).toolId;
         if (toolId) {
           ids.add(toolId);
         }
@@ -311,8 +141,12 @@
     setMessages((prev) => {
       const filtered = prev.filter((msg) => {
-        if (isOperationMessage(msg)) {
-          const toolId = getToolIdFromMessage(msg);
+        if (
+          msg.type === "operation_started" ||
+          msg.type === "operation_pending" ||
+          msg.type === "operation_in_progress"
+        ) {
+          const toolId = (msg as any).toolId || (msg as any).toolCallId;
           if (toolId && completedToolIds.has(toolId)) {
             return false; // Remove - operation completed
           }
@@ -340,8 +174,12 @@
     // Filter local messages: remove duplicates and completed operation messages
     const newLocalMessages = messages.filter((msg) => {
       // Remove operation messages for completed tools
-      if (isOperationMessage(msg)) {
-        const toolId = getToolIdFromMessage(msg);
+      if (
+        msg.type === "operation_started" ||
+        msg.type === "operation_pending" ||
+        msg.type === "operation_in_progress"
+      ) {
+        const toolId = (msg as any).toolId || (msg as any).toolCallId;
         if (toolId && completedToolIds.has(toolId)) {
           return false;
         }
@@ -352,70 +190,7 @@
     });

     // Server messages first (correct order), then new local messages
-    const combined = [...processedInitial, ...newLocalMessages];
-
-    // Post-processing: Remove duplicate assistant messages that can occur during
-    // race conditions (e.g., rapid screen switching during SSE reconnection).
-    // Two assistant messages are considered duplicates if:
-    // - They are both text messages with role "assistant"
-    // - One message's content starts with the other's content (partial vs complete)
-    // - Or they have very similar content (>80% overlap at the start)
-    const deduplicated: ChatMessageData[] = [];
-    for (let i = 0; i < combined.length; i++) {
-      const current = combined[i];
-      // Check if this is an assistant text message
-      if (current.type !== "message" || current.role !== "assistant") {
-        deduplicated.push(current);
-        continue;
-      }
-      // Look for duplicate assistant messages in the rest of the array
-      let dominated = false;
-      for (let j = 0; j < combined.length; j++) {
-        if (i === j) continue;
-        const other = combined[j];
-        if (other.type !== "message" || other.role !== "assistant") continue;
-        const currentContent = current.content || "";
-        const otherContent = other.content || "";
-        // Skip empty messages
-        if (!currentContent.trim() || !otherContent.trim()) continue;
-        // Check if current is a prefix of other (current is incomplete version)
-        if (
-          otherContent.length > currentContent.length &&
-          otherContent.startsWith(currentContent.slice(0, 100))
-        ) {
-          // Current is a shorter/incomplete version of other - skip it
-          dominated = true;
-          break;
-        }
-        // Check if messages are nearly identical (within a small difference)
-        // This catches cases where content differs only slightly
-        const minLen = Math.min(currentContent.length, otherContent.length);
-        const compareLen = Math.min(minLen, 200); // Compare first 200 chars
-        if (
-          compareLen > 50 &&
-          currentContent.slice(0, compareLen) ===
-            otherContent.slice(0, compareLen)
-        ) {
-          // Same prefix - keep the longer one
-          if (otherContent.length > currentContent.length) {
-            dominated = true;
-            break;
-          }
-        }
-      }
-      if (!dominated) {
-        deduplicated.push(current);
-      }
-    }
-    return deduplicated;
+    return [...processedInitial, ...newLocalMessages];
   }, [initialMessages, messages, completedToolIds]);

   async function sendMessage(
@@ -423,8 +198,10 @@
     isUserMessage: boolean = true,
     context?: { url: string; content: string },
   ) {
-    if (!sessionId) return;
+    if (!sessionId) {
+      console.error("[useChatContainer] Cannot send message: no session ID");
+      return;
+    }
     setIsRegionBlockedModalOpen(false);
     if (isUserMessage) {
       const userMessage = createUserMessage(content);
@@ -437,19 +214,31 @@
     setHasTextChunks(false);
     setIsStreamingInitiated(true);
     hasResponseRef.current = false;
-    textFinalizedRef.current = false;
-    streamEndedRef.current = false;
+
+    const dispatcher = createStreamEventDispatcher({
+      setHasTextChunks,
+      setStreamingChunks,
+      streamingChunksRef,
+      hasResponseRef,
+      setMessages,
+      setIsRegionBlockedModalOpen,
+      sessionId,
+      setIsStreamingInitiated,
+      onOperationStarted,
+    });

     try {
       await sendStreamMessage(
         sessionId,
         content,
-        createDispatcher(),
+        dispatcher,
         isUserMessage,
         context,
       );
     } catch (err) {
+      console.error("[useChatContainer] Failed to send message:", err);
       setIsStreamingInitiated(false);
       if (err instanceof Error && err.name === "AbortError") return;
       const errorMessage =
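
The getMessageKey keys earlier in this file drive merge-time deduplication of server and local messages. For intuition, a minimal sketch of how such keys are typically consumed, assuming only that keys are stable strings (the dedupeMessages helper is illustrative, not the hook's actual code):

// Keep the first message seen for each dedup key, preserving order.
function dedupeMessages<T>(messages: T[], keyOf: (msg: T) => string): T[] {
  const seen = new Set<string>();
  const result: T[] = [];
  for (const msg of messages) {
    const key = keyOf(msg);
    if (seen.has(key)) continue; // duplicate from merge or replay: drop it
    seen.add(key);
    result.push(msg);
  }
  return result;
}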

View File

@@ -111,7 +111,6 @@ export type ChatMessageData =
       toolName: string;
       toolId: string;
       operationId: string;
-      taskId?: string; // For SSE reconnection
       message: string;
       timestamp?: string | Date;
     }

View File

@@ -31,6 +31,11 @@ export function MessageList({
     isStreaming,
   });

+  /**
+   * Keeps this for debugging purposes 💆🏽
+   */
+  console.log(messages);
+
   return (
     <div className="relative flex min-h-0 flex-1 flex-col">
       {/* Top fade shadow */}

View File

@@ -1,4 +1,3 @@
-import { INITIAL_STREAM_ID } from "./chat-constants";
 import type {
   ActiveStream,
   StreamChunk,
@@ -11,14 +10,8 @@ import {
   parseSSELine,
 } from "./stream-utils";

-function notifySubscribers(
-  stream: ActiveStream,
-  chunk: StreamChunk,
-  skipStore = false,
-) {
-  if (!skipStore) {
-    stream.chunks.push(chunk);
-  }
+function notifySubscribers(stream: ActiveStream, chunk: StreamChunk) {
+  stream.chunks.push(chunk);
   for (const callback of stream.onChunkCallbacks) {
     try {
       callback(chunk);
@@ -28,114 +21,36 @@
   }
 }

-interface StreamExecutionOptions {
-  stream: ActiveStream;
-  mode: "new" | "reconnect";
-  message?: string;
-  isUserMessage?: boolean;
-  context?: { url: string; content: string };
-  taskId?: string;
-  lastMessageId?: string;
-  retryCount?: number;
-}
-
-async function executeStreamInternal(
-  options: StreamExecutionOptions,
-): Promise<void> {
-  const {
-    stream,
-    mode,
-    message,
-    isUserMessage,
-    context,
-    taskId,
-    lastMessageId = INITIAL_STREAM_ID,
-    retryCount = 0,
-  } = options;
+export async function executeStream(
+  stream: ActiveStream,
+  message: string,
+  isUserMessage: boolean,
+  context?: { url: string; content: string },
+  retryCount: number = 0,
+): Promise<void> {
   const { sessionId, abortController } = stream;
-  const isReconnect = mode === "reconnect";
-
-  if (isReconnect) {
-    if (!taskId) {
-      throw new Error("taskId is required for reconnect mode");
-    }
-    if (lastMessageId === null || lastMessageId === undefined) {
-      throw new Error("lastMessageId is required for reconnect mode");
-    }
-  } else {
-    if (!message) {
-      throw new Error("message is required for new stream mode");
-    }
-    if (isUserMessage === undefined) {
-      throw new Error("isUserMessage is required for new stream mode");
-    }
-  }

   try {
-    let url: string;
-    let fetchOptions: RequestInit;
-
-    if (isReconnect) {
-      url = `/api/chat/tasks/${taskId}/stream?last_message_id=${encodeURIComponent(lastMessageId)}`;
-      fetchOptions = {
-        method: "GET",
-        headers: {
-          Accept: "text/event-stream",
-        },
-        signal: abortController.signal,
-      };
-    } else {
-      url = `/api/chat/sessions/${sessionId}/stream`;
-      fetchOptions = {
-        method: "POST",
-        headers: {
-          "Content-Type": "application/json",
-          Accept: "text/event-stream",
-        },
-        body: JSON.stringify({
-          message,
-          is_user_message: isUserMessage,
-          context: context || null,
-        }),
-        signal: abortController.signal,
-      };
-    }
-
-    const response = await fetch(url, fetchOptions);
+    const url = `/api/chat/sessions/${sessionId}/stream`;
+    const body = JSON.stringify({
+      message,
+      is_user_message: isUserMessage,
+      context: context || null,
+    });
+
+    const response = await fetch(url, {
+      method: "POST",
+      headers: {
+        "Content-Type": "application/json",
+        Accept: "text/event-stream",
+      },
+      body,
+      signal: abortController.signal,
+    });

     if (!response.ok) {
       const errorText = await response.text();
-      let errorCode: string | undefined;
-      let errorMessage = errorText || `HTTP ${response.status}`;
-      try {
-        const parsed = JSON.parse(errorText);
-        if (parsed.detail) {
-          const detail =
-            typeof parsed.detail === "string"
-              ? parsed.detail
-              : parsed.detail.message || JSON.stringify(parsed.detail);
-          errorMessage = detail;
-          errorCode =
-            typeof parsed.detail === "object" ? parsed.detail.code : undefined;
-        }
-      } catch {}
-      const isPermanentError =
-        isReconnect &&
-        (response.status === 404 ||
-          response.status === 403 ||
-          response.status === 410);
-      const error = new Error(errorMessage) as Error & {
-        status?: number;
-        isPermanent?: boolean;
-        taskErrorCode?: string;
-      };
-      error.status = response.status;
-      error.isPermanent = isPermanentError;
-      error.taskErrorCode = errorCode;
-      throw error;
+      throw new Error(errorText || `HTTP ${response.status}`);
     }

     if (!response.body) {
@@ -189,7 +104,9 @@ async function executeStreamInternal(
           );
           return;
         }
-      } catch {}
+      } catch (err) {
+        console.warn("[StreamExecutor] Failed to parse SSE chunk:", err);
+      }
     }
   }
 }
@@ -200,17 +117,19 @@
       return;
     }

-    const isPermanentError =
-      err instanceof Error &&
-      (err as Error & { isPermanent?: boolean }).isPermanent;
-    if (!isPermanentError && retryCount < MAX_RETRIES) {
+    if (retryCount < MAX_RETRIES) {
       const retryDelay = INITIAL_RETRY_DELAY * Math.pow(2, retryCount);
+      console.log(
+        `[StreamExecutor] Retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`,
+      );
       await new Promise((resolve) => setTimeout(resolve, retryDelay));
-      return executeStreamInternal({
-        ...options,
-        retryCount: retryCount + 1,
-      });
+      return executeStream(
+        stream,
+        message,
+        isUserMessage,
+        context,
+        retryCount + 1,
+      );
     }

     stream.status = "error";
@@ -221,35 +140,3 @@
     });
   }
 }
-
-export async function executeStream(
-  stream: ActiveStream,
-  message: string,
-  isUserMessage: boolean,
-  context?: { url: string; content: string },
-  retryCount: number = 0,
-): Promise<void> {
-  return executeStreamInternal({
-    stream,
-    mode: "new",
-    message,
-    isUserMessage,
-    context,
-    retryCount,
-  });
-}
-
-export async function executeTaskReconnect(
-  stream: ActiveStream,
-  taskId: string,
-  lastMessageId: string = INITIAL_STREAM_ID,
-  retryCount: number = 0,
-): Promise<void> {
-  return executeStreamInternal({
-    stream,
-    mode: "reconnect",
-    taskId,
-    lastMessageId,
-    retryCount,
-  });
-}
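
The retry path keeps exponential backoff: attempt n waits INITIAL_RETRY_DELAY * 2^n before recursing. A quick sketch of the resulting schedule; the real constants are defined elsewhere in the module and are not shown in this diff, so the values below are assumptions:

const INITIAL_RETRY_DELAY = 1000; // assumed value, in milliseconds
const MAX_RETRIES = 3; // assumed value

// Prints one delay per retry attempt: 1000ms, 2000ms, 4000ms.
for (let retryCount = 0; retryCount < MAX_RETRIES; retryCount++) {
  const retryDelay = INITIAL_RETRY_DELAY * Math.pow(2, retryCount);
  console.log(`attempt ${retryCount + 1}/${MAX_RETRIES} waits ${retryDelay}ms`);
}

Note that the simplified version retries every failure; the deleted isPermanent check previously short-circuited 403/404/410 responses from the reconnect endpoint.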

View File

@@ -28,7 +28,6 @@ export function normalizeStreamChunk(
   switch (chunk.type) {
     case "text-delta":
-      // Vercel AI SDK sends "delta" for text content
       return { type: "text_chunk", content: chunk.delta };
     case "text-end":
       return { type: "text_ended" };
@@ -64,10 +63,6 @@ export function normalizeStreamChunk(
     case "finish":
       return { type: "stream_end" };
     case "start":
-      // Start event with optional taskId for reconnection
-      return chunk.taskId
-        ? { type: "stream_start", taskId: chunk.taskId }
-        : null;
     case "text-start":
       return null;
     case "tool-input-start":
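
With the taskId branch gone, "start" now falls through to the null-returning case and stream-start events are simply dropped. A minimal sketch of that fall-through shape, using illustrative types rather than the module's real ones:

type SdkChunk = { type: "finish" } | { type: "start" } | { type: "text-start" };
type ChatEvent = { type: "stream_end" } | null;

function normalize(chunk: SdkChunk): ChatEvent {
  switch (chunk.type) {
    case "finish":
      return { type: "stream_end" };
    case "start": // falls through: start events are now ignored
    case "text-start":
      return null;
  }
}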

View File

@@ -1,6 +1,7 @@
 "use client";

 import { IconLaptop } from "@/components/__legacy__/ui/icons";
+import { getHomepageRoute } from "@/lib/constants";
 import { cn } from "@/lib/utils";
 import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
 import { ListChecksIcon } from "@phosphor-icons/react/dist/ssr";
@@ -23,11 +24,11 @@ interface Props {
 export function NavbarLink({ name, href }: Props) {
   const pathname = usePathname();
   const isChatEnabled = useGetFlag(Flag.CHAT);
-  const expectedHomeRoute = isChatEnabled ? "/copilot" : "/library";
+  const homepageRoute = getHomepageRoute(isChatEnabled);

   const isActive =
-    href === expectedHomeRoute
-      ? pathname === "/" || pathname.startsWith(expectedHomeRoute)
+    href === homepageRoute
+      ? pathname === "/" || pathname.startsWith(homepageRoute)
       : pathname.includes(href);

   return (
return ( return (

View File

@@ -66,7 +66,7 @@ export default function useAgentGraph(
   >(null);
   const [xyNodes, setXYNodes] = useState<CustomNode[]>([]);
   const [xyEdges, setXYEdges] = useState<CustomEdge[]>([]);
-  const betaBlocks = useGetFlag(Flag.BETA_BLOCKS) as string[];
+  const betaBlocks = useGetFlag(Flag.BETA_BLOCKS);

   // Filter blocks based on beta flags
   const availableBlocks = useMemo(() => {

View File

@@ -11,3 +11,10 @@ export const API_KEY_HEADER_NAME = "X-API-Key";

 // Layout
 export const NAVBAR_HEIGHT_PX = 60;
+
+// Routes
+export function getHomepageRoute(isChatEnabled?: boolean | null): string {
+  if (isChatEnabled === true) return "/copilot";
+  if (isChatEnabled === false) return "/library";
+  return "/";
+}
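
getHomepageRoute is deliberately tri-state: true and false choose a concrete home, while undefined or null (flag not yet resolved, or no flag available server-side) stays neutral. A quick usage sketch, assuming the function exactly as added above:

import { getHomepageRoute } from "@/lib/constants";

getHomepageRoute(true); // "/copilot" (chat UI enabled)
getHomepageRoute(false); // "/library" (chat UI disabled)
getHomepageRoute(); // "/" (flag not resolved; avoid guessing)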

View File

@@ -1,3 +1,4 @@
+import { getHomepageRoute } from "@/lib/constants";
 import { environment } from "@/services/environment";
 import { Key, storage } from "@/services/storage/local-storage";
 import { type CookieOptions } from "@supabase/ssr";
@@ -70,7 +71,7 @@ export function getRedirectPath(
   }

   if (isAdminPage(path) && userRole !== "admin") {
-    return "/";
+    return getHomepageRoute();
   }

   return null;

View File

@@ -1,3 +1,4 @@
+import { getHomepageRoute } from "@/lib/constants";
 import { environment } from "@/services/environment";
 import { createServerClient } from "@supabase/ssr";
 import { NextResponse, type NextRequest } from "next/server";
@@ -66,7 +67,7 @@ export async function updateSession(request: NextRequest) {
   // 2. Check if user is authenticated but lacks admin role when accessing admin pages
   if (user && userRole !== "admin" && isAdminPage(pathname)) {
-    url.pathname = "/";
+    url.pathname = getHomepageRoute();
     return NextResponse.redirect(url);
   }

View File

@@ -23,7 +23,9 @@ import {
   WebSocketNotification,
 } from "@/lib/autogpt-server-api";
 import { useBackendAPI } from "@/lib/autogpt-server-api/context";
+import { getHomepageRoute } from "@/lib/constants";
 import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
+import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
 import Link from "next/link";
 import { usePathname, useRouter } from "next/navigation";
 import {
@@ -102,6 +104,8 @@ export default function OnboardingProvider({
   const pathname = usePathname();
   const router = useRouter();
   const { isLoggedIn } = useSupabase();
+  const isChatEnabled = useGetFlag(Flag.CHAT);
+  const homepageRoute = getHomepageRoute(isChatEnabled);

   useOnboardingTimezoneDetection();
@@ -146,7 +150,7 @@ export default function OnboardingProvider({
       if (isOnOnboardingRoute) {
         const enabled = await resolveResponse(getV1IsOnboardingEnabled());
         if (!enabled) {
-          router.push("/");
+          router.push(homepageRoute);
           return;
         }
       }
@@ -158,7 +162,7 @@ export default function OnboardingProvider({
           isOnOnboardingRoute &&
           shouldRedirectFromOnboarding(onboarding.completedSteps, pathname)
         ) {
-          router.push("/");
+          router.push(homepageRoute);
         }
       } catch (error) {
         console.error("Failed to initialize onboarding:", error);
@@ -173,7 +177,7 @@ export default function OnboardingProvider({
     }

     initializeOnboarding();
-  }, [api, isOnOnboardingRoute, router, isLoggedIn, pathname]);
+  }, [api, homepageRoute, isOnOnboardingRoute, router, isLoggedIn, pathname]);

   const handleOnboardingNotification = useCallback(
     (notification: WebSocketNotification) => {

View File

@@ -83,10 +83,6 @@ function getPostHogCredentials() {
   };
 }

-function getLaunchDarklyClientId() {
-  return process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
-}
-
 function isProductionBuild() {
   return process.env.NODE_ENV === "production";
 }
@@ -124,10 +120,7 @@ function isVercelPreview() {
 }

 function areFeatureFlagsEnabled() {
-  return (
-    process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true" &&
-    Boolean(process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID)
-  );
+  return process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "enabled";
 }

 function isPostHogEnabled() {
@@ -150,7 +143,6 @@ export const environment = {
   getSupabaseAnonKey,
   getPreviewStealingDev,
   getPostHogCredentials,
-  getLaunchDarklyClientId,
   // Assertions
   isServerSide,
   isClientSide,

View File

@@ -1,59 +0,0 @@
"use client";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { useLDClient } from "launchdarkly-react-client-sdk";
import { useRouter } from "next/navigation";
import { ReactNode, useEffect, useState } from "react";
import { environment } from "../environment";
import { Flag, useGetFlag } from "./use-get-flag";
interface FeatureFlagRedirectProps {
flag: Flag;
whenDisabled: string;
children: ReactNode;
}
export function FeatureFlagPage({
flag,
whenDisabled,
children,
}: FeatureFlagRedirectProps) {
const [isLoading, setIsLoading] = useState(true);
const router = useRouter();
const flagValue = useGetFlag(flag);
const ldClient = useLDClient();
const ldEnabled = environment.areFeatureFlagsEnabled();
const ldReady = Boolean(ldClient);
const flagEnabled = Boolean(flagValue);
useEffect(() => {
const initialize = async () => {
if (!ldEnabled) {
router.replace(whenDisabled);
setIsLoading(false);
return;
}
// Wait for LaunchDarkly to initialize when enabled to prevent race conditions
if (ldEnabled && !ldReady) return;
try {
await ldClient?.waitForInitialization();
if (!flagEnabled) router.replace(whenDisabled);
} catch (error) {
console.error(error);
router.replace(whenDisabled);
} finally {
setIsLoading(false);
}
};
initialize();
}, [ldReady, flagEnabled]);
return isLoading || !flagEnabled ? (
<LoadingSpinner size="large" cover />
) : (
<>{children}</>
);
}

View File

@@ -1,51 +0,0 @@
"use client";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { useLDClient } from "launchdarkly-react-client-sdk";
import { useRouter } from "next/navigation";
import { useEffect } from "react";
import { environment } from "../environment";
import { Flag, useGetFlag } from "./use-get-flag";
interface FeatureFlagRedirectProps {
flag: Flag;
whenEnabled: string;
whenDisabled: string;
}
export function FeatureFlagRedirect({
flag,
whenEnabled,
whenDisabled,
}: FeatureFlagRedirectProps) {
const router = useRouter();
const flagValue = useGetFlag(flag);
const ldEnabled = environment.areFeatureFlagsEnabled();
const ldClient = useLDClient();
const ldReady = Boolean(ldClient);
const flagEnabled = Boolean(flagValue);
useEffect(() => {
const initialize = async () => {
if (!ldEnabled) {
router.replace(whenDisabled);
return;
}
// Wait for LaunchDarkly to initialize when enabled to prevent race conditions
if (ldEnabled && !ldReady) return;
try {
await ldClient?.waitForInitialization();
router.replace(flagEnabled ? whenEnabled : whenDisabled);
} catch (error) {
console.error(error);
router.replace(whenDisabled);
}
};
initialize();
}, [ldReady, flagEnabled]);
return <LoadingSpinner size="large" cover />;
}
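
Both deleted components hinged on the same gate: wait for the LaunchDarkly client to finish initializing, then route based on the flag. Its core is small; a minimal sketch assuming the launchdarkly-react-client-sdk and next/navigation APIs used above (useFlagGate is a hypothetical name, not an existing export):

import { useLDClient } from "launchdarkly-react-client-sdk";
import { useRouter } from "next/navigation";
import { useEffect } from "react";

export function useFlagGate(flagEnabled: boolean, whenDisabled: string) {
  const router = useRouter();
  const ldClient = useLDClient();
  useEffect(() => {
    if (!ldClient) return; // client not created yet; effect re-runs when it is
    ldClient
      .waitForInitialization()
      .then(() => {
        if (!flagEnabled) router.replace(whenDisabled);
      })
      .catch(() => router.replace(whenDisabled)); // fail closed on init errors
  }, [ldClient, flagEnabled, router, whenDisabled]);
}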

View File

@@ -1,6 +1,5 @@
"use client"; "use client";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import * as Sentry from "@sentry/nextjs"; import * as Sentry from "@sentry/nextjs";
import { LDProvider } from "launchdarkly-react-client-sdk"; import { LDProvider } from "launchdarkly-react-client-sdk";
@@ -8,17 +7,17 @@ import type { ReactNode } from "react";
import { useMemo } from "react"; import { useMemo } from "react";
import { environment } from "../environment"; import { environment } from "../environment";
const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
const LAUNCHDARKLY_INIT_TIMEOUT_MS = 5000; const LAUNCHDARKLY_INIT_TIMEOUT_MS = 5000;
export function LaunchDarklyProvider({ children }: { children: ReactNode }) { export function LaunchDarklyProvider({ children }: { children: ReactNode }) {
const { user, isUserLoading } = useSupabase(); const { user, isUserLoading } = useSupabase();
const envEnabled = environment.areFeatureFlagsEnabled(); const isCloud = environment.isCloud();
const clientId = environment.getLaunchDarklyClientId(); const isLaunchDarklyConfigured = isCloud && envEnabled && clientId;
const context = useMemo(() => { const context = useMemo(() => {
if (isUserLoading) return; if (isUserLoading || !user) {
if (!user) {
return { return {
kind: "user" as const, kind: "user" as const,
key: "anonymous", key: "anonymous",
@@ -37,17 +36,15 @@ export function LaunchDarklyProvider({ children }: { children: ReactNode }) {
}; };
}, [user, isUserLoading]); }, [user, isUserLoading]);
if (!envEnabled) { if (!isLaunchDarklyConfigured) {
return <>{children}</>; return <>{children}</>;
} }
if (isUserLoading) {
return <LoadingSpinner size="large" cover />;
}
return ( return (
<LDProvider <LDProvider
clientSideID={clientId ?? ""} // Add this key prop. It will be 'anonymous' when logged out,
key={context.key}
clientSideID={clientId}
context={context} context={context}
timeout={LAUNCHDARKLY_INIT_TIMEOUT_MS} timeout={LAUNCHDARKLY_INIT_TIMEOUT_MS}
reactOptions={{ useCamelCaseFlagKeys: false }} reactOptions={{ useCamelCaseFlagKeys: false }}
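
The added key={context.key} is load-bearing: changing an element's key unmounts and remounts the subtree, which forces the LaunchDarkly client to re-initialize when the context key flips between "anonymous" and a real user id. A minimal sketch of the mechanism with an illustrative component (not the provider itself):

import { type ReactNode, useEffect } from "react";

function InitOnce({ children }: { children: ReactNode }) {
  useEffect(() => {
    console.log("initialized"); // runs again after every keyed remount
  }, []);
  return <>{children}</>;
}

// When userKey flips between "anonymous" and a real id, the key change
// discards the old InitOnce instance and mounts a fresh one.
export function Gate({ userKey, children }: { userKey: string; children: ReactNode }) {
  return <InitOnce key={userKey}>{children}</InitOnce>;
}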

View File

@@ -1,7 +1,6 @@
"use client"; "use client";
import { DEFAULT_SEARCH_TERMS } from "@/app/(platform)/marketplace/components/HeroSection/helpers"; import { DEFAULT_SEARCH_TERMS } from "@/app/(platform)/marketplace/components/HeroSection/helpers";
import { environment } from "@/services/environment";
import { useFlags } from "launchdarkly-react-client-sdk"; import { useFlags } from "launchdarkly-react-client-sdk";
export enum Flag { export enum Flag {
@@ -19,9 +18,24 @@ export enum Flag {
CHAT = "chat", CHAT = "chat",
} }
export type FlagValues = {
[Flag.BETA_BLOCKS]: string[];
[Flag.NEW_BLOCK_MENU]: boolean;
[Flag.NEW_AGENT_RUNS]: boolean;
[Flag.GRAPH_SEARCH]: boolean;
[Flag.ENABLE_ENHANCED_OUTPUT_HANDLING]: boolean;
[Flag.NEW_FLOW_EDITOR]: boolean;
[Flag.BUILDER_VIEW_SWITCH]: boolean;
[Flag.SHARE_EXECUTION_RESULTS]: boolean;
[Flag.AGENT_FAVORITING]: boolean;
[Flag.MARKETPLACE_SEARCH_TERMS]: string[];
[Flag.ENABLE_PLATFORM_PAYMENT]: boolean;
[Flag.CHAT]: boolean;
};
const isPwMockEnabled = process.env.NEXT_PUBLIC_PW_TEST === "true"; const isPwMockEnabled = process.env.NEXT_PUBLIC_PW_TEST === "true";
const defaultFlags = { const mockFlags = {
[Flag.BETA_BLOCKS]: [], [Flag.BETA_BLOCKS]: [],
[Flag.NEW_BLOCK_MENU]: false, [Flag.NEW_BLOCK_MENU]: false,
[Flag.NEW_AGENT_RUNS]: false, [Flag.NEW_AGENT_RUNS]: false,
@@ -36,16 +50,17 @@ const defaultFlags = {
[Flag.CHAT]: false, [Flag.CHAT]: false,
}; };
type FlagValues = typeof defaultFlags; export function useGetFlag<T extends Flag>(flag: T): FlagValues[T] | null {
export function useGetFlag<T extends Flag>(flag: T): FlagValues[T] {
const currentFlags = useFlags<FlagValues>(); const currentFlags = useFlags<FlagValues>();
const flagValue = currentFlags[flag]; const flagValue = currentFlags[flag];
const areFlagsEnabled = environment.areFeatureFlagsEnabled();
if (!areFlagsEnabled || isPwMockEnabled) { const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
return defaultFlags[flag]; const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
if (!isLaunchDarklyConfigured || isPwMockEnabled) {
return mockFlags[flag];
} }
return flagValue ?? defaultFlags[flag]; return flagValue ?? mockFlags[flag];
} }
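
The explicit FlagValues map is what lets useGetFlag return a per-flag type through the indexed access FlagValues[T]. A minimal sketch of the generic pattern with illustrative flags (not the real enum):

enum DemoFlag {
  BETA_ITEMS = "beta-items",
  DARK_MODE = "dark-mode",
}

type DemoFlagValues = {
  [DemoFlag.BETA_ITEMS]: string[];
  [DemoFlag.DARK_MODE]: boolean;
};

// Indexed access narrows the return type per flag.
function getFlag<T extends DemoFlag>(flag: T, values: DemoFlagValues): DemoFlagValues[T] {
  return values[flag];
}

const values: DemoFlagValues = { [DemoFlag.BETA_ITEMS]: [], [DemoFlag.DARK_MODE]: true };
const items = getFlag(DemoFlag.BETA_ITEMS, values); // typed as string[]
const dark = getFlag(DemoFlag.DARK_MODE, values); // typed as boolean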

View File

@@ -8,7 +8,6 @@
 .buildlog/
 .history
 .svn/
-.next/
 migrate_working_dir/

 # IntelliJ related