mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-03 11:24:57 -05:00
Compare commits
1 Commits
dev
...
fix/schedu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6b1f0df58c |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -180,4 +180,3 @@ autogpt_platform/backend/settings.py
|
|||||||
.claude/settings.local.json
|
.claude/settings.local.json
|
||||||
CLAUDE.local.md
|
CLAUDE.local.md
|
||||||
/autogpt_platform/backend/logs
|
/autogpt_platform/backend/logs
|
||||||
.next
|
|
||||||
@@ -1,368 +0,0 @@
|
|||||||
"""Redis Streams consumer for operation completion messages.
|
|
||||||
|
|
||||||
This module provides a consumer (ChatCompletionConsumer) that listens for
|
|
||||||
completion notifications (OperationCompleteMessage) from external services
|
|
||||||
(like Agent Generator) and triggers the appropriate stream registry and
|
|
||||||
chat service updates via process_operation_success/process_operation_failure.
|
|
||||||
|
|
||||||
Why Redis Streams instead of RabbitMQ?
|
|
||||||
--------------------------------------
|
|
||||||
While the project typically uses RabbitMQ for async task queues (e.g., execution
|
|
||||||
queue), Redis Streams was chosen for chat completion notifications because:
|
|
||||||
|
|
||||||
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
|
|
||||||
Streams (via stream_registry) for message persistence and replay. Using Redis
|
|
||||||
Streams for completion notifications keeps all chat streaming infrastructure
|
|
||||||
in one system, simplifying operations and reducing cross-system coordination.
|
|
||||||
|
|
||||||
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
|
|
||||||
allowing consumers to replay missed messages after reconnection. This aligns
|
|
||||||
with the SSE reconnection pattern where clients can resume from last_message_id.
|
|
||||||
|
|
||||||
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
|
|
||||||
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
|
|
||||||
recovering from dead consumers - ideal for the completion callback pattern.
|
|
||||||
|
|
||||||
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
|
|
||||||
stream_registry) provides lower latency than an additional RabbitMQ hop.
|
|
||||||
|
|
||||||
5. **Atomicity with Task State**: Completion processing often needs to update
|
|
||||||
task metadata stored in Redis. Keeping both in Redis enables simpler
|
|
||||||
transactional semantics without distributed coordination.
|
|
||||||
|
|
||||||
The consumer uses Redis Streams with consumer groups for reliable message
|
|
||||||
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
|
|
||||||
stale pending messages from dead consumers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import uuid
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import orjson
|
|
||||||
from prisma import Prisma
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from redis.exceptions import ResponseError
|
|
||||||
|
|
||||||
from backend.data.redis_client import get_redis_async
|
|
||||||
|
|
||||||
from . import stream_registry
|
|
||||||
from .completion_handler import process_operation_failure, process_operation_success
|
|
||||||
from .config import ChatConfig
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
config = ChatConfig()
|
|
||||||
|
|
||||||
|
|
||||||
class OperationCompleteMessage(BaseModel):
    """Message format for operation completion notifications.

    Instances are serialized to JSON and stored under the ``data`` field of a
    stream entry by ``publish_operation_complete``; the consumer parses them
    back in ``ChatCompletionConsumer._handle_message``.
    """

    # ID of the operation that finished; the consumer tries this lookup first.
    operation_id: str
    # ID of the owning task; used as a fallback when the operation lookup misses.
    task_id: str
    # True routes the message to the success handler, False to the failure handler.
    success: bool
    # Result payload forwarded to process_operation_success.
    result: dict | str | None = None
    # Error text forwarded to process_operation_failure.
    error: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionConsumer:
    """Consumer for chat operation completion messages from Redis Streams.

    Entries are read through a Redis consumer group so that several platform
    pods can share one stream. Entries left pending by another consumer are
    not redelivered by Redis on their own; this consumer reclaims them
    explicitly with XAUTOCLAIM once they have been idle long enough.

    The consumer owns a dedicated Prisma client. It is connected lazily, the
    first time a message needs the database, and disconnected in ``stop()``.
    """

    def __init__(self):
        # Background task running ``_consume_messages``; None while stopped.
        self._consumer_task: asyncio.Task | None = None
        self._running = False
        self._prisma: Prisma | None = None
        # Random suffix keeps consumer names distinct within the group.
        self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"

    async def start(self) -> None:
        """Create the consumer group if needed and launch the consume loop.

        Calling ``start()`` while the consumer is already running is a no-op
        (a warning is logged).

        Raises:
            ResponseError: If creating the consumer group fails for any reason
                other than the group already existing.
        """
        if self._running:
            logger.warning("Completion consumer already running")
            return

        # Create consumer group if it doesn't exist
        try:
            redis = await get_redis_async()
            await redis.xgroup_create(
                config.stream_completion_name,
                config.stream_consumer_group,
                id="0",
                mkstream=True,
            )
            logger.info(
                f"Created consumer group '{config.stream_consumer_group}' "
                f"on stream '{config.stream_completion_name}'"
            )
        except ResponseError as e:
            # BUSYGROUP is Redis' "group already exists" reply - not an error here.
            if "BUSYGROUP" in str(e):
                logger.debug(
                    f"Consumer group '{config.stream_consumer_group}' already exists"
                )
            else:
                raise

        self._running = True
        self._consumer_task = asyncio.create_task(self._consume_messages())
        logger.info(
            f"Chat completion consumer started (consumer: {self._consumer_name})"
        )

    async def _ensure_prisma(self) -> Prisma:
        """Return the consumer's Prisma client, connecting it on first use.

        The client is cached only after ``connect()`` succeeds, so a failed
        connection attempt is retried on the next call instead of handing out
        an unconnected client.
        """
        if self._prisma is None:
            database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
            prisma = Prisma(datasource={"url": database_url})
            await prisma.connect()
            self._prisma = prisma
            logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
        return self._prisma

    async def stop(self) -> None:
        """Stop the consume loop and release the Prisma connection."""
        self._running = False

        if self._consumer_task:
            self._consumer_task.cancel()
            try:
                await self._consumer_task
            except asyncio.CancelledError:
                pass
            self._consumer_task = None

        if self._prisma:
            await self._prisma.disconnect()
            self._prisma = None
            logger.info("[COMPLETION] Consumer Prisma client disconnected")

        logger.info("Chat completion consumer stopped")

    async def _consume_messages(self) -> None:
        """Main message consumption loop with retry logic.

        Each iteration first reclaims stale pending entries (XAUTOCLAIM) and
        then blocks on XREADGROUP for new ones. Unexpected errors are retried
        up to ``max_retries`` consecutive times; after that the consumer marks
        itself as stopped so that ``start()`` can launch it again.
        """
        max_retries = 10
        retry_delay = 5  # seconds
        retry_count = 0
        block_timeout = 5000  # milliseconds
        # XAUTOCLAIM scan cursor. Carrying it between calls lets the scan walk
        # the whole pending list instead of re-inspecting its head every time.
        claim_cursor: Any = "0-0"

        while self._running and retry_count < max_retries:
            try:
                redis = await get_redis_async()

                while self._running:
                    # First, claim any stale pending messages from dead consumers.
                    # Redis does NOT auto-redeliver pending messages; we must
                    # explicitly claim them using XAUTOCLAIM.
                    try:
                        claimed_result = await redis.xautoclaim(
                            name=config.stream_completion_name,
                            groupname=config.stream_consumer_group,
                            consumername=self._consumer_name,
                            min_idle_time=config.stream_claim_min_idle_ms,
                            start_id=claim_cursor,
                            count=10,
                        )
                        # xautoclaim returns:
                        # (next_start_id, [(id, data), ...], [deleted_ids])
                        if claimed_result and len(claimed_result) >= 2:
                            claim_cursor = claimed_result[0] or "0-0"
                            claimed_entries = claimed_result[1]
                            if claimed_entries:
                                logger.info(
                                    f"Claimed {len(claimed_entries)} stale pending messages"
                                )
                                for entry_id, data in claimed_entries:
                                    if not self._running:
                                        return
                                    await self._process_entry(redis, entry_id, data)
                    except Exception as e:
                        # Restart the scan from the beginning after a failure.
                        claim_cursor = "0-0"
                        logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")

                    # Read new messages from the stream
                    messages = await redis.xreadgroup(
                        groupname=config.stream_consumer_group,
                        consumername=self._consumer_name,
                        streams={config.stream_completion_name: ">"},
                        block=block_timeout,
                        count=10,
                    )

                    # A completed read proves the connection works. Resetting
                    # here (not right after get_redis_async) keeps persistent
                    # read failures counting towards max_retries.
                    retry_count = 0

                    if not messages:
                        continue

                    for _stream_name, entries in messages:
                        for entry_id, data in entries:
                            if not self._running:
                                return
                            await self._process_entry(redis, entry_id, data)

            except asyncio.CancelledError:
                logger.info("Consumer cancelled")
                return
            except Exception as e:
                retry_count += 1
                logger.error(
                    f"Consumer error (retry {retry_count}/{max_retries}): {e}",
                    exc_info=True,
                )
                if not self._running:
                    # stop() was requested while handling the error.
                    return
                if retry_count < max_retries:
                    await asyncio.sleep(retry_delay)
                else:
                    logger.error("Max retries reached, stopping consumer")
                    # Clear the flag so a later start() is not rejected as
                    # "already running" while no loop is alive.
                    self._running = False
                    return

    async def _process_entry(
        self, redis: Any, entry_id: str, data: dict[Any, Any]
    ) -> None:
        """Process a single stream entry and acknowledge it on success.

        Errors are logged and swallowed: the entry then stays pending and is
        picked up again by XAUTOCLAIM once its idle time exceeds
        ``config.stream_claim_min_idle_ms``.

        Args:
            redis: Redis client connection
            entry_id: The stream entry ID
            data: The entry data dict; the JSON payload lives under ``data``
        """
        try:
            message_data = data.get("data")
            if message_data is None:
                # Field names arrive as bytes when the client does not decode
                # responses; the payload itself is already accepted as bytes.
                message_data = data.get(b"data")

            if message_data:
                await self._handle_message(
                    message_data.encode()
                    if isinstance(message_data, str)
                    else message_data
                )
            else:
                logger.warning(
                    f"Completion entry {entry_id} has no 'data' payload; "
                    "acknowledging without processing"
                )

            # Acknowledge the message after successful processing
            await redis.xack(
                config.stream_completion_name,
                config.stream_consumer_group,
                entry_id,
            )
        except Exception as e:
            logger.error(
                f"Error processing completion message {entry_id}: {e}",
                exc_info=True,
            )
            # Message remains in pending state and will be claimed by
            # XAUTOCLAIM after min_idle_time expires

    async def _handle_message(self, body: bytes) -> None:
        """Parse a completion payload and dispatch it to the right handler.

        Unparseable payloads, unknown tasks and tasks with empty critical
        fields are logged and dropped by returning normally, which lets the
        caller acknowledge the entry instead of retrying it forever.

        Args:
            body: JSON-encoded ``OperationCompleteMessage``.
        """
        try:
            data = orjson.loads(body)
            message = OperationCompleteMessage(**data)
        except Exception as e:
            logger.error(f"Failed to parse completion message: {e}")
            return

        logger.info(
            f"[COMPLETION] Received completion for operation {message.operation_id} "
            f"(task_id={message.task_id}, success={message.success})"
        )

        # Find task in registry: by operation ID first, then by task ID.
        task = await stream_registry.find_task_by_operation_id(message.operation_id)
        if task is None:
            task = await stream_registry.get_task(message.task_id)

        if task is None:
            logger.warning(
                f"[COMPLETION] Task not found for operation {message.operation_id} "
                f"(task_id={message.task_id})"
            )
            return

        logger.info(
            f"[COMPLETION] Found task: task_id={task.task_id}, "
            f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
        )

        # Guard against empty task fields
        if not task.task_id or not task.session_id or not task.tool_call_id:
            logger.error(
                f"[COMPLETION] Task has empty critical fields! "
                f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
                f"tool_call_id={task.tool_call_id!r}"
            )
            return

        if message.success:
            await self._handle_success(task, message)
        else:
            await self._handle_failure(task, message)

    async def _handle_success(
        self,
        task: stream_registry.ActiveTask,
        message: OperationCompleteMessage,
    ) -> None:
        """Handle successful operation completion using the consumer's Prisma client."""
        prisma = await self._ensure_prisma()
        await process_operation_success(task, message.result, prisma)

    async def _handle_failure(
        self,
        task: stream_registry.ActiveTask,
        message: OperationCompleteMessage,
    ) -> None:
        """Handle failed operation completion using the consumer's Prisma client."""
        prisma = await self._ensure_prisma()
        await process_operation_failure(task, message.error, prisma)
|
|
||||||
|
|
||||||
|
|
||||||
# Process-wide consumer singleton; None until start_completion_consumer() runs.
_consumer: ChatCompletionConsumer | None = None


async def start_completion_consumer() -> None:
    """Create the module-level consumer on first use and start it."""
    global _consumer
    consumer = _consumer
    if consumer is None:
        consumer = ChatCompletionConsumer()
        _consumer = consumer
    await consumer.start()
|
|
||||||
|
|
||||||
|
|
||||||
async def stop_completion_consumer() -> None:
    """Stop the module-level consumer, if one exists, and forget it."""
    global _consumer
    if not _consumer:
        # Nothing was started (or it was already stopped).
        return
    await _consumer.stop()
    _consumer = None
|
|
||||||
|
|
||||||
|
|
||||||
async def publish_operation_complete(
    operation_id: str,
    task_id: str,
    success: bool,
    result: dict | str | None = None,
    error: str | None = None,
) -> None:
    """Append an operation completion notification to the Redis stream.

    The arguments are validated by building an ``OperationCompleteMessage``,
    whose JSON form is stored under the ``data`` field of the new entry. The
    stream is capped at ``config.stream_max_length`` entries.

    Args:
        operation_id: The operation ID that completed.
        task_id: The task ID associated with the operation.
        success: Whether the operation succeeded.
        result: The result data (for success).
        error: The error message (for failure).
    """
    completion = OperationCompleteMessage(
        operation_id=operation_id,
        task_id=task_id,
        success=success,
        result=result,
        error=error,
    )

    client = await get_redis_async()
    stream_key = config.stream_completion_name
    entry_fields = {"data": completion.model_dump_json()}
    await client.xadd(stream_key, entry_fields, maxlen=config.stream_max_length)
    logger.info(f"Published completion for operation {operation_id}")
|
|
||||||
@@ -1,344 +0,0 @@
|
|||||||
"""Shared completion handling for operation success and failure.
|
|
||||||
|
|
||||||
This module provides common logic for handling operation completion from both:
|
|
||||||
- The Redis Streams consumer (completion_consumer.py)
|
|
||||||
- The HTTP webhook endpoint (routes.py)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import orjson
|
|
||||||
from prisma import Prisma
|
|
||||||
|
|
||||||
from . import service as chat_service
|
|
||||||
from . import stream_registry
|
|
||||||
from .response_model import StreamError, StreamToolOutputAvailable
|
|
||||||
from .tools.models import ErrorResponse
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Tools that produce agent_json that needs to be saved to library
|
|
||||||
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
|
|
||||||
|
|
||||||
# Keys that should be stripped from agent_json when returning in error responses
|
|
||||||
SENSITIVE_KEYS = frozenset(
|
|
||||||
{
|
|
||||||
"api_key",
|
|
||||||
"apikey",
|
|
||||||
"api_secret",
|
|
||||||
"password",
|
|
||||||
"secret",
|
|
||||||
"credentials",
|
|
||||||
"credential",
|
|
||||||
"token",
|
|
||||||
"access_token",
|
|
||||||
"refresh_token",
|
|
||||||
"private_key",
|
|
||||||
"privatekey",
|
|
||||||
"auth",
|
|
||||||
"authorization",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_agent_json(obj: Any) -> Any:
|
|
||||||
"""Recursively sanitize agent_json by removing sensitive keys.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
obj: The object to sanitize (dict, list, or primitive)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Sanitized copy with sensitive keys removed/redacted
|
|
||||||
"""
|
|
||||||
if isinstance(obj, dict):
|
|
||||||
return {
|
|
||||||
k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
|
|
||||||
for k, v in obj.items()
|
|
||||||
}
|
|
||||||
elif isinstance(obj, list):
|
|
||||||
return [_sanitize_agent_json(item) for item in obj]
|
|
||||||
else:
|
|
||||||
return obj
|
|
||||||
|
|
||||||
|
|
||||||
class ToolMessageUpdateError(Exception):
    """Signals that persisting a tool message's content to the database failed."""
|
|
||||||
|
|
||||||
|
|
||||||
async def _update_tool_message(
    session_id: str,
    tool_call_id: str,
    content: str,
    prisma_client: Prisma | None,
) -> None:
    """Persist new content for the tool message of a given tool call.

    Two write paths exist. With a Prisma client the matching ``chatmessage``
    rows are updated directly; without one the write is delegated to
    ``chat_service._update_pending_operation``.

    Args:
        session_id: The session ID
        tool_call_id: The tool call ID to update
        content: The new content for the message
        prisma_client: Optional Prisma client. If None, uses chat_service.

    Raises:
        ToolMessageUpdateError: If no row matched (Prisma path) or the write
            itself failed. Callers should handle this to avoid marking the
            task as completed with inconsistent state.
    """
    try:
        if not prisma_client:
            # Webhook endpoint path: go through the chat service.
            await chat_service._update_pending_operation(
                session_id=session_id,
                tool_call_id=tool_call_id,
                result=content,
            )
        else:
            # Consumer path: write through the caller-supplied connection.
            rows_changed = await prisma_client.chatmessage.update_many(
                where={
                    "sessionId": session_id,
                    "toolCallId": tool_call_id,
                },
                data={"content": content},
            )
            # Zero affected rows means the message to update does not exist.
            if rows_changed == 0:
                raise ToolMessageUpdateError(
                    f"No message found with tool_call_id={tool_call_id} in session {session_id}"
                )
    except ToolMessageUpdateError:
        # Already the right exception type - pass it through untouched.
        raise
    except Exception as exc:
        logger.error(
            f"[COMPLETION] Failed to update tool message: {exc}", exc_info=True
        )
        raise ToolMessageUpdateError(
            f"Failed to update tool message for tool_call_id={tool_call_id}: {exc}"
        ) from exc
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
|
|
||||||
"""Serialize result to JSON string with sensible defaults.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
result: The result to serialize. Can be a dict, list, string,
|
|
||||||
number, boolean, or None.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
JSON string representation of the result. Returns '{"status": "completed"}'
|
|
||||||
only when result is explicitly None.
|
|
||||||
"""
|
|
||||||
if isinstance(result, str):
|
|
||||||
return result
|
|
||||||
if result is None:
|
|
||||||
return '{"status": "completed"}'
|
|
||||||
return orjson.dumps(result).decode("utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
async def _save_agent_from_result(
    result: dict[str, Any],
    user_id: str | None,
    tool_name: str,
) -> dict[str, Any]:
    """Save the generated agent to the user's library when the result carries one.

    Args:
        result: The result dict that may contain agent_json
        user_id: The user ID to save the agent for
        tool_name: The tool name (create_agent or edit_agent)

    Returns:
        The original ``result`` when there is no user ID or no ``agent_json``;
        an ``agent_saved`` response dict when saving succeeded; an ``error``
        response dict (with a sanitized copy of ``agent_json``) when saving
        raised.
    """
    if not user_id:
        logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
        return result

    agent_json = result.get("agent_json")
    if not agent_json:
        logger.warning(
            f"[COMPLETION] {tool_name} completed but no agent_json in result"
        )
        return result

    try:
        # Imported inside the try so that an import failure is reported
        # through the same error response as a failed save.
        from .tools.agent_generator import save_agent_to_library

        graph, lib_agent = await save_agent_to_library(
            agent_json, user_id, is_update=(tool_name == "edit_agent")
        )
        logger.info(
            f"[COMPLETION] Saved agent '{graph.name}' to library "
            f"(graph_id={graph.id}, library_agent_id={lib_agent.id})"
        )
        # Return a response similar to AgentSavedResponse
        saved_response: dict[str, Any] = {
            "type": "agent_saved",
            "message": f"Agent '{graph.name}' has been saved to your library!",
            "agent_id": graph.id,
            "agent_name": graph.name,
            "library_agent_id": lib_agent.id,
            "library_agent_link": f"/library/agents/{lib_agent.id}",
            "agent_page_link": f"/build?flowID={graph.id}",
        }
        return saved_response
    except Exception as exc:
        logger.error(
            f"[COMPLETION] Failed to save agent to library: {exc}",
            exc_info=True,
        )
        # Report the failure in the result instead of failing the whole
        # operation; secrets are redacted from agent_json before it is returned.
        failure_text = str(exc)
        return {
            "type": "error",
            "message": f"Agent was generated but failed to save: {failure_text}",
            "error": failure_text,
            "agent_json": _sanitize_agent_json(agent_json),
        }
|
|
||||||
|
|
||||||
|
|
||||||
async def process_operation_success(
    task: stream_registry.ActiveTask,
    result: dict | str | None,
    prisma_client: Prisma | None = None,
) -> None:
    """Handle successful operation completion.

    Steps, in order:
      1. For agent generation tools with a dict result, save the agent to the
         library and replace the result with the save response.
      2. Publish the serialized result to the stream registry.
      3. Persist the same serialized result as the tool message content.
      4. Generate the LLM continuation (failures are logged, not raised).
      5. Mark the task as completed and mark the operation as completed.

    Args:
        task: The active task that completed
        result: The result data from the operation
        prisma_client: Optional Prisma client for database operations.
            If None, uses chat_service._update_pending_operation instead.

    Raises:
        ToolMessageUpdateError: If the database update fails. The task will be
            marked as failed instead of completed to avoid inconsistent state.
    """
    # For agent generation tools, save the agent to library
    if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
        result = await _save_agent_from_result(result, task.user_id, task.tool_name)

    # Serialize once and use the same string for the stream and the database,
    # so both always carry an identical payload (including the default that
    # is substituted when result is exactly None).
    result_str = serialize_result(result)

    # Publish result to stream registry
    await stream_registry.publish_chunk(
        task.task_id,
        StreamToolOutputAvailable(
            toolCallId=task.tool_call_id,
            toolName=task.tool_name,
            output=result_str,
            success=True,
        ),
    )

    # Update pending operation in database
    # If this fails, we must not continue to mark the task as completed
    try:
        await _update_tool_message(
            session_id=task.session_id,
            tool_call_id=task.tool_call_id,
            content=result_str,
            prisma_client=prisma_client,
        )
    except ToolMessageUpdateError:
        # DB update failed - mark task as failed to avoid inconsistent state
        logger.error(
            f"[COMPLETION] DB update failed for task {task.task_id}, "
            "marking as failed instead of completed"
        )
        await stream_registry.publish_chunk(
            task.task_id,
            StreamError(errorText="Failed to save operation result to database"),
        )
        await stream_registry.mark_task_completed(task.task_id, status="failed")
        raise

    # Generate LLM continuation with streaming. Best effort: the tool result
    # is already persisted, so a failure here must not fail the completion.
    try:
        await chat_service._generate_llm_continuation_with_streaming(
            session_id=task.session_id,
            user_id=task.user_id,
            task_id=task.task_id,
        )
    except Exception as e:
        logger.error(
            f"[COMPLETION] Failed to generate LLM continuation: {e}",
            exc_info=True,
        )

    # Mark task as completed and release Redis lock
    await stream_registry.mark_task_completed(task.task_id, status="completed")
    try:
        await chat_service._mark_operation_completed(task.tool_call_id)
    except Exception as e:
        logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")

    logger.info(
        f"[COMPLETION] Successfully processed completion for task {task.task_id}"
    )
|
|
||||||
|
|
||||||
|
|
||||||
async def process_operation_failure(
    task: stream_registry.ActiveTask,
    error: str | None,
    prisma_client: Prisma | None = None,
) -> None:
    """Handle failed operation completion.

    Publishes the error to the stream registry, stores an ``ErrorResponse`` as
    the tool message content, then marks the task as failed and the operation
    as completed. A failed database write does not abort the cleanup.

    Args:
        task: The active task that failed
        error: The error message from the operation; a generic text is used
            when it is None or empty
        prisma_client: Optional Prisma client for database operations.
            If None, uses chat_service._update_pending_operation instead.
    """
    failure_text = error or "Operation failed"

    # Tell stream subscribers about the failure first.
    failure_chunk = StreamError(errorText=failure_text)
    await stream_registry.publish_chunk(task.task_id, failure_chunk)

    # Persist the error as the tool message content (best effort).
    stored_error = ErrorResponse(
        message=failure_text,
        error=error,
    )
    try:
        await _update_tool_message(
            session_id=task.session_id,
            tool_call_id=task.tool_call_id,
            content=stored_error.model_dump_json(),
            prisma_client=prisma_client,
        )
    except ToolMessageUpdateError:
        # Cleanup below must still run, so only log the failed write.
        logger.error(
            f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
            "continuing with cleanup"
        )

    # Mark task as failed and release Redis lock
    await stream_registry.mark_task_completed(task.task_id, status="failed")
    try:
        await chat_service._mark_operation_completed(task.tool_call_id)
    except Exception as exc:
        logger.error(f"[COMPLETION] Failed to mark operation completed: {exc}")

    logger.info(
        f"[COMPLETION] Processed failure for task {task.task_id}: {failure_text}"
    )
|
|
||||||
@@ -44,48 +44,6 @@ class ChatConfig(BaseSettings):
|
|||||||
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
|
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Stream registry configuration for SSE reconnection
|
|
||||||
stream_ttl: int = Field(
|
|
||||||
default=3600,
|
|
||||||
description="TTL in seconds for stream data in Redis (1 hour)",
|
|
||||||
)
|
|
||||||
stream_max_length: int = Field(
|
|
||||||
default=10000,
|
|
||||||
description="Maximum number of messages to store per stream",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Redis Streams configuration for completion consumer
|
|
||||||
stream_completion_name: str = Field(
|
|
||||||
default="chat:completions",
|
|
||||||
description="Redis Stream name for operation completions",
|
|
||||||
)
|
|
||||||
stream_consumer_group: str = Field(
|
|
||||||
default="chat_consumers",
|
|
||||||
description="Consumer group name for completion stream",
|
|
||||||
)
|
|
||||||
stream_claim_min_idle_ms: int = Field(
|
|
||||||
default=60000,
|
|
||||||
description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Redis key prefixes for stream registry
|
|
||||||
task_meta_prefix: str = Field(
|
|
||||||
default="chat:task:meta:",
|
|
||||||
description="Prefix for task metadata hash keys",
|
|
||||||
)
|
|
||||||
task_stream_prefix: str = Field(
|
|
||||||
default="chat:stream:",
|
|
||||||
description="Prefix for task message stream keys",
|
|
||||||
)
|
|
||||||
task_op_prefix: str = Field(
|
|
||||||
default="chat:task:op:",
|
|
||||||
description="Prefix for operation ID to task ID mapping keys",
|
|
||||||
)
|
|
||||||
internal_api_key: str | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Langfuse Prompt Management Configuration
|
# Langfuse Prompt Management Configuration
|
||||||
# Note: Langfuse credentials are in Settings().secrets (settings.py)
|
# Note: Langfuse credentials are in Settings().secrets (settings.py)
|
||||||
langfuse_prompt_name: str = Field(
|
langfuse_prompt_name: str = Field(
|
||||||
@@ -124,14 +82,6 @@ class ChatConfig(BaseSettings):
|
|||||||
v = "https://openrouter.ai/api/v1"
|
v = "https://openrouter.ai/api/v1"
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator("internal_api_key", mode="before")
@classmethod
def get_internal_api_key(cls, v):
    """Fall back to the CHAT_INTERNAL_API_KEY environment variable when no value is given."""
    # An explicitly provided value always wins over the environment.
    return os.getenv("CHAT_INTERNAL_API_KEY") if v is None else v
|
|
||||||
|
|
||||||
# Prompt paths for different contexts
|
# Prompt paths for different contexts
|
||||||
PROMPT_PATHS: dict[str, str] = {
|
PROMPT_PATHS: dict[str, str] = {
|
||||||
"default": "prompts/chat_system.md",
|
"default": "prompts/chat_system.md",
|
||||||
|
|||||||
@@ -52,10 +52,6 @@ class StreamStart(StreamBaseResponse):
|
|||||||
|
|
||||||
type: ResponseType = ResponseType.START
|
type: ResponseType = ResponseType.START
|
||||||
messageId: str = Field(..., description="Unique message ID")
|
messageId: str = Field(..., description="Unique message ID")
|
||||||
taskId: str | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StreamFinish(StreamBaseResponse):
|
class StreamFinish(StreamBaseResponse):
|
||||||
|
|||||||
@@ -1,23 +1,19 @@
|
|||||||
"""Chat API routes for chat session management and streaming via SSE."""
|
"""Chat API routes for chat session management and streaming via SSE."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import uuid as uuid_module
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from autogpt_libs import auth
|
from autogpt_libs import auth
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
|
from fastapi import APIRouter, Depends, Query, Security
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
|
|
||||||
from . import service as chat_service
|
from . import service as chat_service
|
||||||
from . import stream_registry
|
|
||||||
from .completion_handler import process_operation_failure, process_operation_success
|
|
||||||
from .config import ChatConfig
|
from .config import ChatConfig
|
||||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||||
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
|
|
||||||
|
|
||||||
config = ChatConfig()
|
config = ChatConfig()
|
||||||
|
|
||||||
@@ -59,15 +55,6 @@ class CreateSessionResponse(BaseModel):
|
|||||||
user_id: str | None
|
user_id: str | None
|
||||||
|
|
||||||
|
|
||||||
class ActiveStreamInfo(BaseModel):
    """Information about an active stream for reconnection."""

    # Identifier of the streaming task the client can reconnect to.
    task_id: str
    # Redis Stream message ID for resumption.
    last_message_id: str
    # Operation ID for completion tracking.
    operation_id: str
    # Name of the tool being executed.
    tool_name: str
|
|
||||||
|
|
||||||
|
|
||||||
class SessionDetailResponse(BaseModel):
|
class SessionDetailResponse(BaseModel):
|
||||||
"""Response model providing complete details for a chat session, including messages."""
|
"""Response model providing complete details for a chat session, including messages."""
|
||||||
|
|
||||||
@@ -76,7 +63,6 @@ class SessionDetailResponse(BaseModel):
|
|||||||
updated_at: str
|
updated_at: str
|
||||||
user_id: str | None
|
user_id: str | None
|
||||||
messages: list[dict]
|
messages: list[dict]
|
||||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
|
||||||
|
|
||||||
|
|
||||||
class SessionSummaryResponse(BaseModel):
|
class SessionSummaryResponse(BaseModel):
|
||||||
@@ -95,14 +81,6 @@ class ListSessionsResponse(BaseModel):
|
|||||||
total: int
|
total: int
|
||||||
|
|
||||||
|
|
||||||
class OperationCompleteRequest(BaseModel):
    """Request model for external completion webhook."""

    # Whether the operation finished successfully.
    success: bool
    # Result payload of the operation; may be a dict, a string, or absent.
    result: dict | str | None = None
    # Error description; may be absent.
    error: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
# ========== Routes ==========
|
# ========== Routes ==========
|
||||||
|
|
||||||
|
|
||||||
@@ -188,14 +166,13 @@ async def get_session(
|
|||||||
Retrieve the details of a specific chat session.
|
Retrieve the details of a specific chat session.
|
||||||
|
|
||||||
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
||||||
If there's an active stream for this session, returns the task_id for reconnection.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session_id: The unique identifier for the desired chat session.
|
session_id: The unique identifier for the desired chat session.
|
||||||
user_id: The optional authenticated user ID, or None for anonymous access.
|
user_id: The optional authenticated user ID, or None for anonymous access.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
|
SessionDetailResponse: Details for the requested session, or None if not found.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
session = await get_chat_session(session_id, user_id)
|
session = await get_chat_session(session_id, user_id)
|
||||||
@@ -203,28 +180,11 @@ async def get_session(
|
|||||||
raise NotFoundError(f"Session {session_id} not found.")
|
raise NotFoundError(f"Session {session_id} not found.")
|
||||||
|
|
||||||
messages = [message.model_dump() for message in session.messages]
|
messages = [message.model_dump() for message in session.messages]
|
||||||
|
logger.info(
|
||||||
# Check if there's an active stream for this session
|
f"Returning session {session_id}: "
|
||||||
active_stream_info = None
|
f"message_count={len(messages)}, "
|
||||||
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
f"roles={[m.get('role') for m in messages]}"
|
||||||
session_id, user_id
|
|
||||||
)
|
)
|
||||||
if active_task:
|
|
||||||
# Filter out the in-progress assistant message from the session response.
|
|
||||||
# The client will receive the complete assistant response through the SSE
|
|
||||||
# stream replay instead, preventing duplicate content.
|
|
||||||
if messages and messages[-1].get("role") == "assistant":
|
|
||||||
messages = messages[:-1]
|
|
||||||
|
|
||||||
# Use "0-0" as last_message_id to replay the stream from the beginning.
|
|
||||||
# Since we filtered out the cached assistant message, the client needs
|
|
||||||
# the full stream to reconstruct the response.
|
|
||||||
active_stream_info = ActiveStreamInfo(
|
|
||||||
task_id=active_task.task_id,
|
|
||||||
last_message_id="0-0",
|
|
||||||
operation_id=active_task.operation_id,
|
|
||||||
tool_name=active_task.tool_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
return SessionDetailResponse(
|
return SessionDetailResponse(
|
||||||
id=session.session_id,
|
id=session.session_id,
|
||||||
@@ -232,7 +192,6 @@ async def get_session(
|
|||||||
updated_at=session.updated_at.isoformat(),
|
updated_at=session.updated_at.isoformat(),
|
||||||
user_id=session.user_id or None,
|
user_id=session.user_id or None,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
active_stream=active_stream_info,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -252,112 +211,49 @@ async def stream_chat_post(
|
|||||||
- Tool call UI elements (if invoked)
|
- Tool call UI elements (if invoked)
|
||||||
- Tool execution results
|
- Tool execution results
|
||||||
|
|
||||||
The AI generation runs in a background task that continues even if the client disconnects.
|
|
||||||
All chunks are written to Redis for reconnection support. If the client disconnects,
|
|
||||||
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session_id: The chat session identifier to associate with the streamed messages.
|
session_id: The chat session identifier to associate with the streamed messages.
|
||||||
request: Request body containing message, is_user_message, and optional context.
|
request: Request body containing message, is_user_message, and optional context.
|
||||||
user_id: Optional authenticated user ID.
|
user_id: Optional authenticated user ID.
|
||||||
Returns:
|
Returns:
|
||||||
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
|
StreamingResponse: SSE-formatted response chunks.
|
||||||
containing the task_id for reconnection.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import asyncio
|
|
||||||
|
|
||||||
session = await _validate_and_get_session(session_id, user_id)
|
session = await _validate_and_get_session(session_id, user_id)
|
||||||
|
|
||||||
# Create a task in the stream registry for reconnection support
|
|
||||||
task_id = str(uuid_module.uuid4())
|
|
||||||
operation_id = str(uuid_module.uuid4())
|
|
||||||
await stream_registry.create_task(
|
|
||||||
task_id=task_id,
|
|
||||||
session_id=session_id,
|
|
||||||
user_id=user_id,
|
|
||||||
tool_call_id="chat_stream", # Not a tool call, but needed for the model
|
|
||||||
tool_name="chat",
|
|
||||||
operation_id=operation_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Background task that runs the AI generation independently of SSE connection
|
|
||||||
async def run_ai_generation():
|
|
||||||
try:
|
|
||||||
# Emit a start event with task_id for reconnection
|
|
||||||
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
|
|
||||||
await stream_registry.publish_chunk(task_id, start_chunk)
|
|
||||||
|
|
||||||
async for chunk in chat_service.stream_chat_completion(
|
|
||||||
session_id,
|
|
||||||
request.message,
|
|
||||||
is_user_message=request.is_user_message,
|
|
||||||
user_id=user_id,
|
|
||||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
|
||||||
context=request.context,
|
|
||||||
):
|
|
||||||
# Write to Redis (subscribers will receive via XREAD)
|
|
||||||
await stream_registry.publish_chunk(task_id, chunk)
|
|
||||||
|
|
||||||
# Mark task as completed
|
|
||||||
await stream_registry.mark_task_completed(task_id, "completed")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error in background AI generation for session {session_id}: {e}"
|
|
||||||
)
|
|
||||||
await stream_registry.mark_task_completed(task_id, "failed")
|
|
||||||
|
|
||||||
# Start the AI generation in a background task
|
|
||||||
bg_task = asyncio.create_task(run_ai_generation())
|
|
||||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
|
||||||
|
|
||||||
# SSE endpoint that subscribes to the task's stream
|
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
subscriber_queue = None
|
chunk_count = 0
|
||||||
try:
|
first_chunk_type: str | None = None
|
||||||
# Subscribe to the task stream (this replays existing messages + live updates)
|
async for chunk in chat_service.stream_chat_completion(
|
||||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
session_id,
|
||||||
task_id=task_id,
|
request.message,
|
||||||
user_id=user_id,
|
is_user_message=request.is_user_message,
|
||||||
last_message_id="0-0", # Get all messages from the beginning
|
user_id=user_id,
|
||||||
)
|
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||||
|
context=request.context,
|
||||||
if subscriber_queue is None:
|
):
|
||||||
yield StreamFinish().to_sse()
|
if chunk_count < 3:
|
||||||
yield "data: [DONE]\n\n"
|
logger.info(
|
||||||
return
|
"Chat stream chunk",
|
||||||
|
extra={
|
||||||
# Read from the subscriber queue and yield to SSE
|
"session_id": session_id,
|
||||||
while True:
|
"chunk_type": str(chunk.type),
|
||||||
try:
|
},
|
||||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
)
|
||||||
yield chunk.to_sse()
|
if not first_chunk_type:
|
||||||
|
first_chunk_type = str(chunk.type)
|
||||||
# Check for finish signal
|
chunk_count += 1
|
||||||
if isinstance(chunk, StreamFinish):
|
yield chunk.to_sse()
|
||||||
break
|
logger.info(
|
||||||
except asyncio.TimeoutError:
|
"Chat stream completed",
|
||||||
# Send heartbeat to keep connection alive
|
extra={
|
||||||
yield StreamHeartbeat().to_sse()
|
"session_id": session_id,
|
||||||
|
"chunk_count": chunk_count,
|
||||||
except GeneratorExit:
|
"first_chunk_type": first_chunk_type,
|
||||||
pass # Client disconnected - background task continues
|
},
|
||||||
except Exception as e:
|
)
|
||||||
logger.error(f"Error in SSE stream for task {task_id}: {e}")
|
# AI SDK protocol termination
|
||||||
finally:
|
yield "data: [DONE]\n\n"
|
||||||
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
|
||||||
if subscriber_queue is not None:
|
|
||||||
try:
|
|
||||||
await stream_registry.unsubscribe_from_task(
|
|
||||||
task_id, subscriber_queue
|
|
||||||
)
|
|
||||||
except Exception as unsub_err:
|
|
||||||
logger.error(
|
|
||||||
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
|
||||||
yield "data: [DONE]\n\n"
|
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
event_generator(),
|
event_generator(),
|
||||||
@@ -470,251 +366,6 @@ async def session_assign_user(
|
|||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
# ========== Task Streaming (SSE Reconnection) ==========
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
    "/tasks/{task_id}/stream",
)
async def stream_task(
    task_id: str,
    user_id: str | None = Depends(auth.get_user_id),
    last_message_id: str = Query(
        default="0-0",
        description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
    ),
):
    """
    Reconnect to a long-running task's SSE stream.

    When a long-running operation (like agent generation) starts, the client
    receives a task_id. If the connection drops, the client can reconnect
    using this endpoint to resume receiving updates.

    Args:
        task_id: The task ID from the operation_started response.
        user_id: Authenticated user ID for ownership validation.
        last_message_id: Last Redis Stream message ID received ("0-0" for full replay).

    Returns:
        StreamingResponse: SSE-formatted response chunks starting after
        last_message_id. Heartbeats are emitted while no chunk arrives, and
        the stream is terminated with a ``data: [DONE]`` line.

    Raises:
        HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
    """
    # Check task existence and expiry before subscribing
    task, error_code = await stream_registry.get_task_with_expiry_info(task_id)

    if error_code == "TASK_EXPIRED":
        raise HTTPException(
            status_code=410,
            detail={
                "code": "TASK_EXPIRED",
                "message": "This operation has expired. Please try again.",
            },
        )

    if error_code == "TASK_NOT_FOUND":
        raise HTTPException(
            status_code=404,
            detail={
                "code": "TASK_NOT_FOUND",
                "message": f"Task {task_id} not found.",
            },
        )

    # Validate ownership if task has an owner
    if task and task.user_id and user_id != task.user_id:
        raise HTTPException(
            status_code=403,
            detail={
                "code": "ACCESS_DENIED",
                "message": "You do not have access to this task.",
            },
        )

    # Get subscriber queue from stream registry
    subscriber_queue = await stream_registry.subscribe_to_task(
        task_id=task_id,
        user_id=user_id,
        last_message_id=last_message_id,
    )

    if subscriber_queue is None:
        raise HTTPException(
            status_code=404,
            detail={
                "code": "TASK_NOT_FOUND",
                "message": f"Task {task_id} not found or access denied.",
            },
        )

    async def event_generator() -> AsyncGenerator[str, None]:
        import asyncio

        heartbeat_interval = 15.0  # Send heartbeat every 15 seconds
        try:
            while True:
                try:
                    # Wait for next chunk with timeout for heartbeats
                    chunk = await asyncio.wait_for(
                        subscriber_queue.get(), timeout=heartbeat_interval
                    )
                except asyncio.TimeoutError:
                    # Send heartbeat to keep connection alive
                    yield StreamHeartbeat().to_sse()
                    continue

                yield chunk.to_sse()

                # Check for finish signal
                if isinstance(chunk, StreamFinish):
                    break
        except Exception as e:
            logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
        finally:
            # Unsubscribe when client disconnects or stream ends
            try:
                await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
            except Exception as unsub_err:
                logger.error(
                    f"Error unsubscribing from task {task_id}: {unsub_err}",
                    exc_info=True,
                )

        # AI SDK protocol termination. This deliberately lives AFTER the
        # try/finally rather than inside ``finally``: when the generator is
        # being closed (client disconnect) or the request task is cancelled,
        # yielding from ``finally`` would raise "async generator ignored
        # GeneratorExit" or suspend with the cancellation still pending.
        # Here it is reached on normal completion and after a logged error,
        # including when unsubscribe itself failed.
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            "x-vercel-ai-ui-message-stream": "v1",
        },
    )
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
    "/tasks/{task_id}",
)
async def get_task_status(
    task_id: str,
    user_id: str | None = Depends(auth.get_user_id),
) -> dict:
    """
    Report the current state of a long-running task.

    Args:
        task_id: The task ID to check.
        user_id: Authenticated user ID for ownership validation.

    Returns:
        dict: task_id, session_id, status, tool_name, operation_id and the
        ISO-formatted created_at timestamp of the task.

    Raises:
        NotFoundError: If task_id is not found or user doesn't have access.
    """
    task = await stream_registry.get_task(task_id)

    # A missing task and a task owned by another user are reported with the
    # same error, so the response does not reveal that the task exists.
    owned_by_other = task is not None and task.user_id and user_id != task.user_id
    if task is None or owned_by_other:
        raise NotFoundError(f"Task {task_id} not found.")

    copied_fields = ("task_id", "session_id", "status", "tool_name", "operation_id")
    payload = {name: getattr(task, name) for name in copied_fields}
    payload["created_at"] = task.created_at.isoformat()
    return payload
|
|
||||||
|
|
||||||
|
|
||||||
# ========== External Completion Webhook ==========
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
    "/operations/{operation_id}/complete",
    status_code=200,
)
async def complete_operation(
    operation_id: str,
    request: OperationCompleteRequest,
    x_api_key: str | None = Header(default=None),
) -> dict:
    """
    External completion webhook for long-running operations.

    Called by Agent Generator (or other services) when an operation completes.
    The task is looked up by operation ID and handed to
    process_operation_success or process_operation_failure depending on
    ``request.success``.

    Args:
        operation_id: The operation ID to complete.
        request: Completion payload with success status and result/error.
        x_api_key: Internal API key for authentication.

    Returns:
        dict: ``{"status": "ok", "task_id": ...}`` for the completed task.

    Raises:
        HTTPException: 503 if no internal API key is configured, 401 if the
            supplied key does not match, 404 if the operation is not found.
    """
    import hmac

    # Validate internal API key - reject if not configured or invalid
    expected_key = config.internal_api_key
    if not expected_key:
        logger.error(
            "Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
        )
        raise HTTPException(
            status_code=503,
            detail="Webhook not available: internal API key not configured",
        )

    # Constant-time comparison so response timing does not leak how much of
    # the key matched. Both sides are encoded to bytes because
    # compare_digest rejects non-ASCII str arguments; a missing header is
    # treated as an empty key, which can never equal the non-empty secret.
    supplied_key = x_api_key or ""
    if not hmac.compare_digest(
        supplied_key.encode("utf-8"), expected_key.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid API key")

    # Find task by operation_id
    task = await stream_registry.find_task_by_operation_id(operation_id)
    if task is None:
        raise HTTPException(
            status_code=404,
            detail=f"Operation {operation_id} not found",
        )

    logger.info(
        f"Received completion webhook for operation {operation_id} "
        f"(task_id={task.task_id}, success={request.success})"
    )

    if request.success:
        await process_operation_success(task, request.result)
    else:
        await process_operation_failure(task, request.error)

    return {"status": "ok", "task_id": task.task_id}
|
|
||||||
|
|
||||||
|
|
||||||
# ========== Configuration ==========
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/config/ttl", status_code=200)
async def get_ttl_config() -> dict:
    """
    Expose the stream Time-To-Live setting.

    The TTL determines how long clients can reconnect to an active stream.

    Returns:
        dict: The configured TTL under "stream_ttl_seconds" and the same
        value multiplied by 1000 under "stream_ttl_ms".
    """
    ttl_seconds = config.stream_ttl
    return {
        "stream_ttl_seconds": ttl_seconds,
        "stream_ttl_ms": ttl_seconds * 1000,
    }
|
|
||||||
|
|
||||||
|
|
||||||
# ========== Health Check ==========
|
# ========== Health Check ==========
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,704 +0,0 @@
|
|||||||
"""Stream registry for managing reconnectable SSE streams.
|
|
||||||
|
|
||||||
This module provides a registry for tracking active streaming tasks and their
|
|
||||||
messages. It uses Redis for all state management (no in-memory state), making
|
|
||||||
pods stateless and horizontally scalable.
|
|
||||||
|
|
||||||
Architecture:
|
|
||||||
- Redis Stream: Persists all messages for replay and real-time delivery
|
|
||||||
- Redis Hash: Task metadata (status, session_id, etc.)
|
|
||||||
|
|
||||||
Subscribers:
|
|
||||||
1. Replay missed messages from Redis Stream (XREAD)
|
|
||||||
2. Listen for live updates via blocking XREAD
|
|
||||||
3. No in-memory state required on the subscribing pod
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
import orjson
|
|
||||||
|
|
||||||
from backend.data.redis_client import get_redis_async
|
|
||||||
|
|
||||||
from .config import ChatConfig
|
|
||||||
from .response_model import StreamBaseResponse, StreamError, StreamFinish
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
config = ChatConfig()

# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
_local_tasks: dict[str, asyncio.Task] = {}

# Track listener tasks per subscriber queue for cleanup
# Maps queue id() to (task_id, asyncio.Task) for proper cleanup on unsubscribe
_listener_tasks: dict[int, tuple[str, asyncio.Task]] = {}

# Timeout for putting chunks into subscriber queues (seconds)
# If the queue is full and doesn't drain within this time, send an overflow error
QUEUE_PUT_TIMEOUT = 5.0

# Lua script for atomic compare-and-swap status update (idempotent completion)
# KEYS[1] is the hash to update, ARGV[1] is the new status value.
# Returns 1 if the status was "running" and has been replaced, 0 otherwise
# (e.g. already completed/failed, or no status field present)
COMPLETE_TASK_SCRIPT = """
local current = redis.call("HGET", KEYS[1], "status")
if current == "running" then
redis.call("HSET", KEYS[1], "status", ARGV[1])
return 1
end
return 0
"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ActiveTask:
|
|
||||||
"""Represents an active streaming task (metadata only, no in-memory queues)."""
|
|
||||||
|
|
||||||
task_id: str
|
|
||||||
session_id: str
|
|
||||||
user_id: str | None
|
|
||||||
tool_call_id: str
|
|
||||||
tool_name: str
|
|
||||||
operation_id: str
|
|
||||||
status: Literal["running", "completed", "failed"] = "running"
|
|
||||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
|
||||||
asyncio_task: asyncio.Task | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_task_meta_key(task_id: str) -> str:
    """Build the Redis key for a task's metadata from the configured prefix."""
    prefix = config.task_meta_prefix
    return f"{prefix}{task_id}"
|
|
||||||
|
|
||||||
|
|
||||||
def _get_task_stream_key(task_id: str) -> str:
    """Build the Redis key for a task's message stream from the configured prefix."""
    prefix = config.task_stream_prefix
    return f"{prefix}{task_id}"
|
|
||||||
|
|
||||||
|
|
||||||
def _get_operation_mapping_key(operation_id: str) -> str:
    """Build the Redis key that maps an operation_id to its task_id."""
    prefix = config.task_op_prefix
    return f"{prefix}{operation_id}"
|
|
||||||
|
|
||||||
|
|
||||||
async def create_task(
    task_id: str,
    session_id: str,
    user_id: str | None,
    tool_call_id: str,
    tool_name: str,
    operation_id: str,
) -> ActiveTask:
    """Register a new streaming task in Redis.

    Writes the task metadata as a Redis hash and stores an
    operation_id -> task_id mapping key; both expire after
    ``config.stream_ttl``.

    Args:
        task_id: Unique identifier for the task
        session_id: Chat session ID
        user_id: User ID (may be None for anonymous; stored as "")
        tool_call_id: Tool call ID from the LLM
        tool_name: Name of the tool being executed
        operation_id: Operation ID for webhook callbacks

    Returns:
        The created ActiveTask instance (metadata only)
    """
    task = ActiveTask(
        task_id=task_id,
        session_id=session_id,
        user_id=user_id,
        tool_call_id=tool_call_id,
        tool_name=tool_name,
        operation_id=operation_id,
    )

    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)
    op_key = _get_operation_mapping_key(operation_id)

    # Hash values must be strings, so a missing user is stored as "" and the
    # creation time in ISO format.
    metadata = {
        "task_id": task_id,
        "session_id": session_id,
        "user_id": user_id or "",
        "tool_call_id": tool_call_id,
        "tool_name": tool_name,
        "operation_id": operation_id,
        "status": task.status,
        "created_at": task.created_at.isoformat(),
    }
    await redis.hset(meta_key, mapping=metadata)  # type: ignore[misc]
    await redis.expire(meta_key, config.stream_ttl)

    # Lets webhook callbacks resolve the task from the operation ID alone.
    await redis.set(op_key, task_id, ex=config.stream_ttl)

    logger.debug(f"Created task {task_id} for session {session_id}")
    return task
|
|
||||||
|
|
||||||
|
|
||||||
async def publish_chunk(
    task_id: str,
    chunk: StreamBaseResponse,
) -> str:
    """Append one response chunk to the task's Redis Stream.

    The chunk is serialized to JSON and stored under the "data" field. The
    stream is capped at ``config.stream_max_length`` and its TTL is refreshed
    to ``config.stream_ttl`` on every publish. Failures are logged rather
    than raised.

    Args:
        task_id: Task ID to publish to
        chunk: The stream response chunk to publish

    Returns:
        The Redis Stream message ID, or "0-0" if the chunk could not be
        written.
    """
    payload = chunk.model_dump_json()
    # Value reported to the caller when the write to Redis fails.
    message_id = "0-0"

    try:
        redis = await get_redis_async()
        stream_key = _get_task_stream_key(task_id)

        raw_id = await redis.xadd(
            stream_key,
            {"data": payload},
            maxlen=config.stream_max_length,
        )
        # The client may hand back str or bytes; normalize to str.
        if isinstance(raw_id, str):
            message_id = raw_id
        else:
            message_id = raw_id.decode()

        # Keep the stream's lifetime in step with the task metadata TTL.
        await redis.expire(stream_key, config.stream_ttl)
    except Exception as e:
        logger.error(
            f"Failed to publish chunk for task {task_id}: {e}",
            exc_info=True,
        )

    return message_id
|
|
||||||
|
|
||||||
|
|
||||||
async def subscribe_to_task(
    task_id: str,
    user_id: str | None,
    last_message_id: str = "0-0",
) -> asyncio.Queue[StreamBaseResponse] | None:
    """Subscribe to a task's stream with replay of missed messages.

    Messages after ``last_message_id`` are replayed from the Redis Stream into
    a fresh queue. If the task is still running, a background listener task
    is started to deliver live updates from the same stream; otherwise a
    StreamFinish marker is appended so the consumer terminates.

    Args:
        task_id: Task ID to subscribe to
        user_id: User ID for ownership validation
        last_message_id: Last Redis Stream message ID received ("0-0" for full replay)

    Returns:
        An asyncio Queue that will receive stream chunks, or None if task not found
        or user doesn't have access
    """
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)
    meta: dict[Any, Any] = await redis.hgetall(meta_key)  # type: ignore[misc]

    if not meta:
        logger.debug(f"Task {task_id} not found in Redis")
        return None

    # Note: Redis client uses decode_responses=True, so keys are strings
    task_status = meta.get("status", "")
    task_user_id = meta.get("user_id", "") or None

    # Validate ownership - if task has an owner, requester must match
    if task_user_id and user_id != task_user_id:
        logger.warning(
            f"User {user_id} denied access to task {task_id} "
            f"owned by {task_user_id}"
        )
        return None

    subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
    stream_key = _get_task_stream_key(task_id)

    # Step 1: Replay messages from Redis Stream.
    # This read must NOT block: passing block=0 makes Redis wait forever when
    # nothing newer than last_message_id exists (e.g. reconnecting to a task
    # whose messages were all received), which would hang the subscription
    # and never deliver the finish marker. Omitting ``block`` returns
    # immediately with whatever is already stored.
    # NOTE(review): a single read replays at most 1000 messages; for a task
    # that is no longer running anything beyond that is not replayed.
    messages = await redis.xread({stream_key: last_message_id}, count=1000)

    replayed_count = 0
    replay_last_id = last_message_id
    if messages:
        for _stream_name, stream_messages in messages:
            for msg_id, msg_data in stream_messages:
                replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
                # Note: Redis client uses decode_responses=True, so keys are strings
                if "data" in msg_data:
                    try:
                        chunk_data = orjson.loads(msg_data["data"])
                        chunk = _reconstruct_chunk(chunk_data)
                        if chunk:
                            await subscriber_queue.put(chunk)
                            replayed_count += 1
                    except Exception as e:
                        logger.warning(f"Failed to replay message: {e}")

    logger.debug(f"Task {task_id}: replayed {replayed_count} messages")

    # Step 2: If task is still running, start stream listener for live updates
    if task_status == "running":
        listener_task = asyncio.create_task(
            _stream_listener(task_id, subscriber_queue, replay_last_id)
        )
        # Track listener task for cleanup on unsubscribe
        _listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
    else:
        # Task is completed/failed - add finish marker
        await subscriber_queue.put(StreamFinish())

    return subscriber_queue
|
|
||||||
|
|
||||||
|
|
||||||
async def _stream_listener(
    task_id: str,
    subscriber_queue: asyncio.Queue[StreamBaseResponse],
    last_replayed_id: str,
) -> None:
    """Forward new Redis Stream entries for a task into a subscriber queue.

    Uses blocking XREAD starting after ``last_replayed_id`` so that entries
    written between replay and live listening are neither skipped nor
    duplicated (the problem a pub/sub subscription would have).

    The listener stops when:
      * a ``StreamFinish`` chunk is read from the stream,
      * an XREAD timeout occurs and the task status is no longer "running"
        (including the case where the status field is missing), or
      * an unexpected error occurs (a ``StreamFinish`` is then pushed on a
        best-effort basis so the subscriber is not left waiting).

    Args:
        task_id: Task ID to listen for.
        subscriber_queue: Queue to deliver reconstructed chunks to.
        last_replayed_id: Last message ID from replay (continue from here).

    Raises:
        asyncio.CancelledError: Re-raised when the listener is cancelled.
    """
    queue_id = id(subscriber_queue)
    # Track the last successfully delivered message ID for recovery hints
    last_delivered_id = last_replayed_id

    try:
        redis = await get_redis_async()
        stream_key = _get_task_stream_key(task_id)
        # Loop-invariant: only needed on XREAD timeouts, computed once.
        meta_key = _get_task_meta_key(task_id)
        current_id = last_replayed_id

        while True:
            # Block for up to 30 seconds waiting for new messages.
            # This allows periodic checking if task is still running.
            messages = await redis.xread(
                {stream_key: current_id}, block=30000, count=100
            )

            if not messages:
                # Timeout - check if task is still running
                status = await redis.hget(meta_key, "status")  # type: ignore[misc]
                # Anything other than "running" is terminal. A missing status
                # (None) is treated as terminal too: presumably the metadata
                # hash expired or was deleted, in which case polling could
                # never observe completion and the listener would spin
                # forever. NOTE(review): confirm against the metadata TTL policy.
                if status != "running":
                    try:
                        await asyncio.wait_for(
                            subscriber_queue.put(StreamFinish()),
                            timeout=QUEUE_PUT_TIMEOUT,
                        )
                    except asyncio.TimeoutError:
                        logger.warning(
                            f"Timeout delivering finish event for task {task_id}"
                        )
                    break
                continue

            for _stream_name, stream_messages in messages:
                for msg_id, msg_data in stream_messages:
                    # Advance the read cursor even for entries we cannot
                    # decode, so a bad entry is not re-read forever.
                    current_id = msg_id if isinstance(msg_id, str) else msg_id.decode()

                    if "data" not in msg_data:
                        continue

                    try:
                        chunk_data = orjson.loads(msg_data["data"])
                        chunk = _reconstruct_chunk(chunk_data)
                        if chunk:
                            try:
                                await asyncio.wait_for(
                                    subscriber_queue.put(chunk),
                                    timeout=QUEUE_PUT_TIMEOUT,
                                )
                                # Update last delivered ID on successful delivery
                                last_delivered_id = current_id
                            except asyncio.TimeoutError:
                                logger.warning(
                                    f"Subscriber queue full for task {task_id}, "
                                    f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
                                )
                                # Send overflow error with recovery info
                                try:
                                    overflow_error = StreamError(
                                        errorText="Message delivery timeout - some messages may have been missed",
                                        code="QUEUE_OVERFLOW",
                                        details={
                                            "last_delivered_id": last_delivered_id,
                                            "recovery_hint": f"Reconnect with last_message_id={last_delivered_id}",
                                        },
                                    )
                                    subscriber_queue.put_nowait(overflow_error)
                                except asyncio.QueueFull:
                                    # Queue is completely stuck, nothing more we can do
                                    logger.error(
                                        f"Cannot deliver overflow error for task {task_id}, "
                                        "queue completely blocked"
                                    )

                            # Stop listening on finish
                            if isinstance(chunk, StreamFinish):
                                return
                    except Exception as e:
                        # A single malformed entry must not kill the listener.
                        logger.warning(f"Error processing stream message: {e}")

    except asyncio.CancelledError:
        logger.debug(f"Stream listener cancelled for task {task_id}")
        raise  # Re-raise to propagate cancellation
    except Exception as e:
        logger.error(f"Stream listener error for task {task_id}: {e}")
        # On error, send finish to unblock subscriber
        try:
            await asyncio.wait_for(
                subscriber_queue.put(StreamFinish()),
                timeout=QUEUE_PUT_TIMEOUT,
            )
        except (asyncio.TimeoutError, asyncio.QueueFull):
            logger.warning(
                f"Could not deliver finish event for task {task_id} after error"
            )
    finally:
        # Clean up listener task mapping on exit
        _listener_tasks.pop(queue_id, None)
|
|
||||||
|
|
||||||
|
|
||||||
async def mark_task_completed(
    task_id: str,
    status: Literal["completed", "failed"] = "completed",
) -> bool:
    """Transition a task to a final status and announce it on the stream.

    The status change is performed by ``COMPLETE_TASK_SCRIPT`` via ``EVAL``;
    a result of ``0`` means the transition did not happen (the task was
    already completed/failed), which makes repeated calls harmless. Only
    after the status is updated is a ``StreamFinish`` chunk published, and
    a publish failure is logged rather than raised.

    Args:
        task_id: Task ID to mark as completed.
        status: Final status ("completed" or "failed").

    Returns:
        True if this call performed the transition, False if the task was
        already completed/failed.
    """
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)

    # Status first: it is the source of truth. The script performs the
    # compare-and-swap so concurrent callers cannot both "win".
    swap_outcome = await redis.eval(COMPLETE_TASK_SCRIPT, 1, meta_key, status)  # type: ignore[misc]
    if swap_outcome == 0:
        logger.debug(f"Task {task_id} already completed/failed, skipping")
        return False

    # Finish event second, best-effort: listeners also poll the status.
    try:
        await publish_chunk(task_id, StreamFinish())
    except Exception as e:
        logger.error(
            f"Failed to publish finish event for task {task_id}: {e}. "
            "Listeners will detect completion via status polling."
        )

    # Drop the process-local task handle, if this process holds one.
    _local_tasks.pop(task_id, None)
    return True
|
|
||||||
|
|
||||||
|
|
||||||
async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
    """Resolve an operation ID to its task.

    Reads the operation-to-task mapping key from Redis and then loads the
    task through ``get_task``.

    Args:
        operation_id: Operation ID to search for.

    Returns:
        ActiveTask if the mapping and the task exist, None otherwise.
    """
    redis = await get_redis_async()
    mapped_id = await redis.get(_get_operation_mapping_key(operation_id))

    if not mapped_id:
        return None

    # Tolerate both decoded (str) and raw (bytes) client responses.
    if isinstance(mapped_id, bytes):
        mapped_id = mapped_id.decode()
    return await get_task(mapped_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_task(task_id: str) -> ActiveTask | None:
    """Load a task's metadata hash from Redis and wrap it in an ActiveTask.

    Args:
        task_id: Task ID to look up.

    Returns:
        ActiveTask built from the metadata hash, or None when the hash is
        empty or missing.
    """
    redis = await get_redis_async()
    meta: dict[Any, Any] = await redis.hgetall(_get_task_meta_key(task_id))  # type: ignore[misc]

    if not meta:
        return None

    # Note: Redis client uses decode_responses=True, so keys/values are strings.
    # Plain string fields all default to "" when absent.
    text_fields = {
        name: meta.get(name, "")
        for name in (
            "task_id",
            "session_id",
            "tool_call_id",
            "tool_name",
            "operation_id",
        )
    }
    return ActiveTask(
        **text_fields,
        # An empty stored user_id means "no owner".
        user_id=meta.get("user_id", "") or None,
        status=meta.get("status", "running"),  # type: ignore[arg-type]
    )
|
|
||||||
|
|
||||||
|
|
||||||
async def get_task_with_expiry_info(
    task_id: str,
) -> tuple[ActiveTask | None, str | None]:
    """Look up a task and, when it is missing, say why.

    The second tuple element is an error code:
      - None when the task metadata was found
      - "TASK_EXPIRED" when the metadata hash is gone but the task's stream
        still has entries
      - "TASK_NOT_FOUND" when neither metadata nor stream entries exist

    Args:
        task_id: Task ID to look up.

    Returns:
        Tuple of (ActiveTask or None, error_code or None).
    """
    redis = await get_redis_async()
    meta_key = _get_task_meta_key(task_id)
    stream_key = _get_task_stream_key(task_id)

    meta: dict[Any, Any] = await redis.hgetall(meta_key)  # type: ignore[misc]

    if not meta:
        # Metadata is gone; leftover stream entries tell expiry apart from
        # a task that never existed.
        remaining = await redis.xlen(stream_key)
        return None, ("TASK_EXPIRED" if remaining > 0 else "TASK_NOT_FOUND")

    # Note: Redis client uses decode_responses=True, so keys/values are strings
    task = ActiveTask(
        task_id=meta.get("task_id", ""),
        session_id=meta.get("session_id", ""),
        user_id=meta.get("user_id", "") or None,
        tool_call_id=meta.get("tool_call_id", ""),
        tool_name=meta.get("tool_name", ""),
        operation_id=meta.get("operation_id", ""),
        status=meta.get("status", "running"),  # type: ignore[arg-type]
    )
    return task, None
|
|
||||||
|
|
||||||
|
|
||||||
async def get_active_task_for_session(
    session_id: str,
    user_id: str | None = None,
) -> tuple[ActiveTask | None, str]:
    """Get the active (running) task for a session, if any.

    Iterates all task metadata keys with SCAN (pattern
    ``config.task_meta_prefix*``) and returns the first task whose
    ``session_id`` matches and whose status is "running". Tasks that have
    an owner are only returned when ``user_id`` equals that owner.

    Args:
        session_id: Session ID to look up.
        user_id: User ID for ownership validation (optional).

    Returns:
        Tuple of (ActiveTask if found and running, last message ID in the
        task's Redis Stream). When no task matches the result is
        ``(None, "0-0")``; "0-0" is also used when the stream is empty or
        its last ID cannot be read.
    """
    redis = await get_redis_async()

    # Scan Redis for task metadata keys
    cursor = 0

    while True:
        cursor, keys = await redis.scan(
            cursor, match=f"{config.task_meta_prefix}*", count=100
        )

        for key in keys:
            meta: dict[Any, Any] = await redis.hgetall(key)  # type: ignore[misc]
            if not meta:
                continue

            # Note: Redis client uses decode_responses=True, so keys/values are strings
            task_session_id = meta.get("session_id", "")
            task_status = meta.get("status", "")
            task_user_id = meta.get("user_id", "") or None
            task_id = meta.get("task_id", "")

            if task_session_id != session_id or task_status != "running":
                continue

            # Validate ownership - if task has an owner, requester must match
            if task_user_id and user_id != task_user_id:
                continue

            # Get the last message ID from Redis Stream
            stream_key = _get_task_stream_key(task_id)
            last_id = "0-0"
            try:
                messages = await redis.xrevrange(stream_key, count=1)
                if messages:
                    msg_id = messages[0][0]
                    last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
            except Exception as e:
                # Best-effort: fall back to "0-0" rather than failing the lookup.
                logger.warning(f"Failed to get last message ID: {e}")

            return (
                ActiveTask(
                    task_id=task_id,
                    session_id=task_session_id,
                    user_id=task_user_id,
                    tool_call_id=meta.get("tool_call_id", ""),
                    tool_name=meta.get("tool_name", ""),
                    operation_id=meta.get("operation_id", ""),
                    status="running",
                ),
                last_id,
            )

        # SCAN signals the end of the iteration with a zero cursor.
        if cursor == 0:
            break

    return None, "0-0"
|
|
||||||
|
|
||||||
|
|
||||||
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
    """Rebuild a typed stream response from its parsed JSON form.

    The ``"type"`` entry of ``chunk_data`` selects the response class, which
    is then instantiated with ``chunk_data`` as keyword arguments.

    Args:
        chunk_data: Parsed JSON data from Redis.

    Returns:
        The reconstructed response object, or None when the type is unknown
        or the class cannot be built from the data (both cases are logged).
    """
    from .response_model import (
        ResponseType,
        StreamError,
        StreamFinish,
        StreamHeartbeat,
        StreamStart,
        StreamTextDelta,
        StreamTextEnd,
        StreamTextStart,
        StreamToolInputAvailable,
        StreamToolInputStart,
        StreamToolOutputAvailable,
        StreamUsage,
    )

    # (response type, class) pairs, keyed below by the enum's string value.
    known_types = (
        (ResponseType.START, StreamStart),
        (ResponseType.FINISH, StreamFinish),
        (ResponseType.TEXT_START, StreamTextStart),
        (ResponseType.TEXT_DELTA, StreamTextDelta),
        (ResponseType.TEXT_END, StreamTextEnd),
        (ResponseType.TOOL_INPUT_START, StreamToolInputStart),
        (ResponseType.TOOL_INPUT_AVAILABLE, StreamToolInputAvailable),
        (ResponseType.TOOL_OUTPUT_AVAILABLE, StreamToolOutputAvailable),
        (ResponseType.ERROR, StreamError),
        (ResponseType.USAGE, StreamUsage),
        (ResponseType.HEARTBEAT, StreamHeartbeat),
    )
    type_to_class: dict[str, type[StreamBaseResponse]] = {
        response_type.value: response_cls
        for response_type, response_cls in known_types
    }

    chunk_type = chunk_data.get("type")
    chunk_class = type_to_class.get(chunk_type)  # type: ignore[arg-type]

    if chunk_class is not None:
        try:
            return chunk_class(**chunk_data)
        except Exception as e:
            logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}")
            return None

    logger.warning(f"Unknown chunk type: {chunk_type}")
    return None
|
|
||||||
|
|
||||||
|
|
||||||
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
    """Remember the asyncio.Task that belongs to ``task_id`` in this process.

    The handle is stored in the module-level ``_local_tasks`` mapping only;
    it is a local reference kept for cleanup, while the task state itself
    lives in Redis.

    Args:
        task_id: Task ID.
        asyncio_task: The asyncio Task to track.
    """
    # Overwrites any handle previously registered under the same task ID.
    _local_tasks[task_id] = asyncio_task
|
|
||||||
|
|
||||||
|
|
||||||
async def unsubscribe_from_task(
    task_id: str,
    subscriber_queue: asyncio.Queue[StreamBaseResponse],
) -> None:
    """Tear down the listener that feeds a subscriber's queue.

    Looks up (and removes) the listener task registered under the queue's
    identity, cancels it if it is still running, and waits up to five
    seconds for the cancellation to take effect so the listener does not
    leak.

    Args:
        task_id: Task ID.
        subscriber_queue: The subscriber's queue used to look up the listener task.
    """
    queue_id = id(subscriber_queue)
    registration = _listener_tasks.pop(queue_id, None)

    if registration is None:
        logger.debug(
            f"No listener task found for task {task_id} queue {queue_id} "
            "(may have already completed)"
        )
        return

    stored_task_id, listener_task = registration

    # A mismatch is only reported; the listener is still shut down below.
    if stored_task_id != task_id:
        logger.warning(
            f"Task ID mismatch in unsubscribe: expected {task_id}, "
            f"found {stored_task_id}"
        )

    if listener_task.done():
        logger.debug(f"Listener task for task {task_id} already completed")
        return

    # Request cancellation, then give the listener a bounded time to unwind.
    listener_task.cancel()

    try:
        await asyncio.wait_for(listener_task, timeout=5.0)
    except asyncio.CancelledError:
        # Expected - the task was successfully cancelled
        pass
    except asyncio.TimeoutError:
        logger.warning(
            f"Timeout waiting for listener task cancellation for task {task_id}"
        )
    except Exception as e:
        logger.error(f"Error during listener task cancellation for task {task_id}: {e}")

    logger.debug(f"Successfully unsubscribed from task {task_id}")
|
|
||||||
@@ -10,7 +10,6 @@ from .add_understanding import AddUnderstandingTool
|
|||||||
from .agent_output import AgentOutputTool
|
from .agent_output import AgentOutputTool
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
from .create_agent import CreateAgentTool
|
from .create_agent import CreateAgentTool
|
||||||
from .customize_agent import CustomizeAgentTool
|
|
||||||
from .edit_agent import EditAgentTool
|
from .edit_agent import EditAgentTool
|
||||||
from .find_agent import FindAgentTool
|
from .find_agent import FindAgentTool
|
||||||
from .find_block import FindBlockTool
|
from .find_block import FindBlockTool
|
||||||
@@ -35,7 +34,6 @@ logger = logging.getLogger(__name__)
|
|||||||
TOOL_REGISTRY: dict[str, BaseTool] = {
|
TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||||
"add_understanding": AddUnderstandingTool(),
|
"add_understanding": AddUnderstandingTool(),
|
||||||
"create_agent": CreateAgentTool(),
|
"create_agent": CreateAgentTool(),
|
||||||
"customize_agent": CustomizeAgentTool(),
|
|
||||||
"edit_agent": EditAgentTool(),
|
"edit_agent": EditAgentTool(),
|
||||||
"find_agent": FindAgentTool(),
|
"find_agent": FindAgentTool(),
|
||||||
"find_block": FindBlockTool(),
|
"find_block": FindBlockTool(),
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from .core import (
|
|||||||
DecompositionStep,
|
DecompositionStep,
|
||||||
LibraryAgentSummary,
|
LibraryAgentSummary,
|
||||||
MarketplaceAgentSummary,
|
MarketplaceAgentSummary,
|
||||||
customize_template,
|
|
||||||
decompose_goal,
|
decompose_goal,
|
||||||
enrich_library_agents_from_steps,
|
enrich_library_agents_from_steps,
|
||||||
extract_search_terms_from_steps,
|
extract_search_terms_from_steps,
|
||||||
@@ -20,7 +19,6 @@ from .core import (
|
|||||||
get_library_agent_by_graph_id,
|
get_library_agent_by_graph_id,
|
||||||
get_library_agent_by_id,
|
get_library_agent_by_id,
|
||||||
get_library_agents_for_generation,
|
get_library_agents_for_generation,
|
||||||
graph_to_json,
|
|
||||||
json_to_graph,
|
json_to_graph,
|
||||||
save_agent_to_library,
|
save_agent_to_library,
|
||||||
search_marketplace_agents_for_generation,
|
search_marketplace_agents_for_generation,
|
||||||
@@ -38,7 +36,6 @@ __all__ = [
|
|||||||
"LibraryAgentSummary",
|
"LibraryAgentSummary",
|
||||||
"MarketplaceAgentSummary",
|
"MarketplaceAgentSummary",
|
||||||
"check_external_service_health",
|
"check_external_service_health",
|
||||||
"customize_template",
|
|
||||||
"decompose_goal",
|
"decompose_goal",
|
||||||
"enrich_library_agents_from_steps",
|
"enrich_library_agents_from_steps",
|
||||||
"extract_search_terms_from_steps",
|
"extract_search_terms_from_steps",
|
||||||
@@ -51,7 +48,6 @@ __all__ = [
|
|||||||
"get_library_agent_by_id",
|
"get_library_agent_by_id",
|
||||||
"get_library_agents_for_generation",
|
"get_library_agents_for_generation",
|
||||||
"get_user_message_for_error",
|
"get_user_message_for_error",
|
||||||
"graph_to_json",
|
|
||||||
"is_external_service_configured",
|
"is_external_service_configured",
|
||||||
"json_to_graph",
|
"json_to_graph",
|
||||||
"save_agent_to_library",
|
"save_agent_to_library",
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ from backend.data.graph import (
|
|||||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||||
|
|
||||||
from .service import (
|
from .service import (
|
||||||
customize_template_external,
|
|
||||||
decompose_goal_external,
|
decompose_goal_external,
|
||||||
generate_agent_external,
|
generate_agent_external,
|
||||||
generate_agent_patch_external,
|
generate_agent_patch_external,
|
||||||
@@ -550,21 +549,15 @@ async def decompose_goal(
|
|||||||
async def generate_agent(
|
async def generate_agent(
|
||||||
instructions: DecompositionResult | dict[str, Any],
|
instructions: DecompositionResult | dict[str, Any],
|
||||||
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
|
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
|
||||||
operation_id: str | None = None,
|
|
||||||
task_id: str | None = None,
|
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Generate agent JSON from instructions.
|
"""Generate agent JSON from instructions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
instructions: Structured instructions from decompose_goal
|
instructions: Structured instructions from decompose_goal
|
||||||
library_agents: User's library agents available for sub-agent composition
|
library_agents: User's library agents available for sub-agent composition
|
||||||
operation_id: Operation ID for async processing (enables Redis Streams
|
|
||||||
completion notification)
|
|
||||||
task_id: Task ID for async processing (enables Redis Streams persistence
|
|
||||||
and SSE delivery)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
Agent JSON dict, error dict {"type": "error", ...}, or None on error
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||||
@@ -572,13 +565,8 @@ async def generate_agent(
|
|||||||
_check_service_configured()
|
_check_service_configured()
|
||||||
logger.info("Calling external Agent Generator service for generate_agent")
|
logger.info("Calling external Agent Generator service for generate_agent")
|
||||||
result = await generate_agent_external(
|
result = await generate_agent_external(
|
||||||
dict(instructions), _to_dict_list(library_agents), operation_id, task_id
|
dict(instructions), _to_dict_list(library_agents)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Don't modify async response
|
|
||||||
if result and result.get("status") == "accepted":
|
|
||||||
return result
|
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
if isinstance(result, dict) and result.get("type") == "error":
|
if isinstance(result, dict) and result.get("type") == "error":
|
||||||
return result
|
return result
|
||||||
@@ -752,15 +740,32 @@ async def save_agent_to_library(
|
|||||||
return created_graph, library_agents[0]
|
return created_graph, library_agents[0]
|
||||||
|
|
||||||
|
|
||||||
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
async def get_agent_as_json(
|
||||||
"""Convert a Graph object to JSON format for the agent generator.
|
agent_id: str, user_id: str | None
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Fetch an agent and convert to JSON format for editing.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph: Graph object to convert
|
agent_id: Graph ID or library agent ID
|
||||||
|
user_id: User ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Agent as JSON dict
|
Agent as JSON dict or None if not found
|
||||||
"""
|
"""
|
||||||
|
graph = await get_graph(agent_id, version=None, user_id=user_id)
|
||||||
|
|
||||||
|
if not graph and user_id:
|
||||||
|
try:
|
||||||
|
library_agent = await library_db.get_library_agent(agent_id, user_id)
|
||||||
|
graph = await get_graph(
|
||||||
|
library_agent.graph_id, version=None, user_id=user_id
|
||||||
|
)
|
||||||
|
except NotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not graph:
|
||||||
|
return None
|
||||||
|
|
||||||
nodes = []
|
nodes = []
|
||||||
for node in graph.nodes:
|
for node in graph.nodes:
|
||||||
nodes.append(
|
nodes.append(
|
||||||
@@ -797,41 +802,10 @@ def graph_to_json(graph: Graph) -> dict[str, Any]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def get_agent_as_json(
    agent_id: str, user_id: str | None
) -> dict[str, Any] | None:
    """Load an agent's graph and return it in JSON form for editing.

    ``agent_id`` is first tried as a graph ID. If that finds nothing and a
    ``user_id`` is given, it is tried as a library agent ID and the library
    agent's graph is loaded instead.

    Args:
        agent_id: Graph ID or library agent ID.
        user_id: User ID.

    Returns:
        Agent as JSON dict, or None if no graph could be found.
    """
    graph = await get_graph(agent_id, version=None, user_id=user_id)

    if not graph and user_id:
        # Fallback: interpret the ID as a library agent ID.
        try:
            library_agent = await library_db.get_library_agent(agent_id, user_id)
            graph = await get_graph(
                library_agent.graph_id, version=None, user_id=user_id
            )
        except NotFoundError:
            # No such library agent; fall through to the "not found" result.
            pass

    return graph_to_json(graph) if graph else None
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_agent_patch(
|
async def generate_agent_patch(
|
||||||
update_request: str,
|
update_request: str,
|
||||||
current_agent: dict[str, Any],
|
current_agent: dict[str, Any],
|
||||||
library_agents: list[AgentSummary] | None = None,
|
library_agents: list[AgentSummary] | None = None,
|
||||||
operation_id: str | None = None,
|
|
||||||
task_id: str | None = None,
|
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Update an existing agent using natural language.
|
"""Update an existing agent using natural language.
|
||||||
|
|
||||||
@@ -844,12 +818,10 @@ async def generate_agent_patch(
|
|||||||
update_request: Natural language description of changes
|
update_request: Natural language description of changes
|
||||||
current_agent: Current agent JSON
|
current_agent: Current agent JSON
|
||||||
library_agents: User's library agents available for sub-agent composition
|
library_agents: User's library agents available for sub-agent composition
|
||||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
|
||||||
task_id: Task ID for async processing (enables Redis Streams callback)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||||
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
error dict {"type": "error", ...}, or None on unexpected error
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||||
@@ -857,43 +829,5 @@ async def generate_agent_patch(
|
|||||||
_check_service_configured()
|
_check_service_configured()
|
||||||
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
||||||
return await generate_agent_patch_external(
|
return await generate_agent_patch_external(
|
||||||
update_request,
|
update_request, current_agent, _to_dict_list(library_agents)
|
||||||
current_agent,
|
|
||||||
_to_dict_list(library_agents),
|
|
||||||
operation_id,
|
|
||||||
task_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def customize_template(
    template_agent: dict[str, Any],
    modification_request: str,
    context: str = "",
) -> dict[str, Any] | None:
    """Adapt a template/marketplace agent to a natural-language request.

    Used when a user wants to tailor a template or marketplace agent before
    adding it to their library. The work is delegated to the external Agent
    Generator service through ``customize_template_external``; this function
    only verifies that the service is configured and logs the call.

    Args:
        template_agent: The template agent JSON to customize.
        modification_request: Natural language description of customizations.
        context: Additional context (e.g., answers to previous questions).

    Returns:
        Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
        error dict {"type": "error", ...}, or None on unexpected error.

    Raises:
        AgentGeneratorNotConfiguredError: If the external service is not configured.
    """
    _check_service_configured()
    logger.info("Calling external Agent Generator service for customize_template")
    customized = await customize_template_external(
        template_agent, modification_request, context
    )
    return customized
|
)
|
||||||
|
|||||||
@@ -139,10 +139,11 @@ async def decompose_goal_external(
|
|||||||
"""
|
"""
|
||||||
client = _get_client()
|
client = _get_client()
|
||||||
|
|
||||||
if context:
|
# Build the request payload
|
||||||
description = f"{description}\n\nAdditional context from user:\n{context}"
|
|
||||||
|
|
||||||
payload: dict[str, Any] = {"description": description}
|
payload: dict[str, Any] = {"description": description}
|
||||||
|
if context:
|
||||||
|
# The external service uses user_instruction for additional context
|
||||||
|
payload["user_instruction"] = context
|
||||||
if library_agents:
|
if library_agents:
|
||||||
payload["library_agents"] = library_agents
|
payload["library_agents"] = library_agents
|
||||||
|
|
||||||
@@ -212,45 +213,24 @@ async def decompose_goal_external(
|
|||||||
async def generate_agent_external(
|
async def generate_agent_external(
|
||||||
instructions: dict[str, Any],
|
instructions: dict[str, Any],
|
||||||
library_agents: list[dict[str, Any]] | None = None,
|
library_agents: list[dict[str, Any]] | None = None,
|
||||||
operation_id: str | None = None,
|
|
||||||
task_id: str | None = None,
|
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Call the external service to generate an agent from instructions.
|
"""Call the external service to generate an agent from instructions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
instructions: Structured instructions from decompose_goal
|
instructions: Structured instructions from decompose_goal
|
||||||
library_agents: User's library agents available for sub-agent composition
|
library_agents: User's library agents available for sub-agent composition
|
||||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
|
||||||
task_id: Task ID for async processing (enables Redis Streams callback)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
|
Agent JSON dict on success, or error dict {"type": "error", ...} on error
|
||||||
"""
|
"""
|
||||||
client = _get_client()
|
client = _get_client()
|
||||||
|
|
||||||
# Build request payload
|
|
||||||
payload: dict[str, Any] = {"instructions": instructions}
|
payload: dict[str, Any] = {"instructions": instructions}
|
||||||
if library_agents:
|
if library_agents:
|
||||||
payload["library_agents"] = library_agents
|
payload["library_agents"] = library_agents
|
||||||
if operation_id and task_id:
|
|
||||||
payload["operation_id"] = operation_id
|
|
||||||
payload["task_id"] = task_id
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await client.post("/api/generate-agent", json=payload)
|
response = await client.post("/api/generate-agent", json=payload)
|
||||||
|
|
||||||
# Handle 202 Accepted for async processing
|
|
||||||
if response.status_code == 202:
|
|
||||||
logger.info(
|
|
||||||
f"Agent Generator accepted async request "
|
|
||||||
f"(operation_id={operation_id}, task_id={task_id})"
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"status": "accepted",
|
|
||||||
"operation_id": operation_id,
|
|
||||||
"task_id": task_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
@@ -282,8 +262,6 @@ async def generate_agent_patch_external(
|
|||||||
update_request: str,
|
update_request: str,
|
||||||
current_agent: dict[str, Any],
|
current_agent: dict[str, Any],
|
||||||
library_agents: list[dict[str, Any]] | None = None,
|
library_agents: list[dict[str, Any]] | None = None,
|
||||||
operation_id: str | None = None,
|
|
||||||
task_id: str | None = None,
|
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Call the external service to generate a patch for an existing agent.
|
"""Call the external service to generate a patch for an existing agent.
|
||||||
|
|
||||||
@@ -291,40 +269,21 @@ async def generate_agent_patch_external(
|
|||||||
update_request: Natural language description of changes
|
update_request: Natural language description of changes
|
||||||
current_agent: Current agent JSON
|
current_agent: Current agent JSON
|
||||||
library_agents: User's library agents available for sub-agent composition
|
library_agents: User's library agents available for sub-agent composition
|
||||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
|
||||||
task_id: Task ID for async processing (enables Redis Streams callback)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
|
Updated agent JSON, clarifying questions dict, or error dict on error
|
||||||
"""
|
"""
|
||||||
client = _get_client()
|
client = _get_client()
|
||||||
|
|
||||||
# Build request payload
|
|
||||||
payload: dict[str, Any] = {
|
payload: dict[str, Any] = {
|
||||||
"update_request": update_request,
|
"update_request": update_request,
|
||||||
"current_agent_json": current_agent,
|
"current_agent_json": current_agent,
|
||||||
}
|
}
|
||||||
if library_agents:
|
if library_agents:
|
||||||
payload["library_agents"] = library_agents
|
payload["library_agents"] = library_agents
|
||||||
if operation_id and task_id:
|
|
||||||
payload["operation_id"] = operation_id
|
|
||||||
payload["task_id"] = task_id
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await client.post("/api/update-agent", json=payload)
|
response = await client.post("/api/update-agent", json=payload)
|
||||||
|
|
||||||
# Handle 202 Accepted for async processing
|
|
||||||
if response.status_code == 202:
|
|
||||||
logger.info(
|
|
||||||
f"Agent Generator accepted async update request "
|
|
||||||
f"(operation_id={operation_id}, task_id={task_id})"
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"status": "accepted",
|
|
||||||
"operation_id": operation_id,
|
|
||||||
"task_id": task_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
@@ -368,77 +327,6 @@ async def generate_agent_patch_external(
|
|||||||
return _create_error_response(error_msg, "unexpected_error")
|
return _create_error_response(error_msg, "unexpected_error")
|
||||||
|
|
||||||
|
|
||||||
async def customize_template_external(
|
|
||||||
template_agent: dict[str, Any],
|
|
||||||
modification_request: str,
|
|
||||||
context: str = "",
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""Call the external service to customize a template/marketplace agent.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
template_agent: The template agent JSON to customize
|
|
||||||
modification_request: Natural language description of customizations
|
|
||||||
context: Additional context (e.g., answers to previous questions)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Customized agent JSON, clarifying questions dict, or error dict on error
|
|
||||||
"""
|
|
||||||
client = _get_client()
|
|
||||||
|
|
||||||
request = modification_request
|
|
||||||
if context:
|
|
||||||
request = f"{modification_request}\n\nAdditional context from user:\n{context}"
|
|
||||||
|
|
||||||
payload: dict[str, Any] = {
|
|
||||||
"template_agent_json": template_agent,
|
|
||||||
"modification_request": request,
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await client.post("/api/template-modification", json=payload)
|
|
||||||
response.raise_for_status()
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
if not data.get("success"):
|
|
||||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
|
||||||
error_type = data.get("error_type", "unknown")
|
|
||||||
logger.error(
|
|
||||||
f"Agent Generator template customization failed: {error_msg} "
|
|
||||||
f"(type: {error_type})"
|
|
||||||
)
|
|
||||||
return _create_error_response(error_msg, error_type)
|
|
||||||
|
|
||||||
# Check if it's clarifying questions
|
|
||||||
if data.get("type") == "clarifying_questions":
|
|
||||||
return {
|
|
||||||
"type": "clarifying_questions",
|
|
||||||
"questions": data.get("questions", []),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Check if it's an error passed through
|
|
||||||
if data.get("type") == "error":
|
|
||||||
return _create_error_response(
|
|
||||||
data.get("error", "Unknown error"),
|
|
||||||
data.get("error_type", "unknown"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Otherwise return the customized agent JSON
|
|
||||||
return data.get("agent_json")
|
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
error_type, error_msg = _classify_http_error(e)
|
|
||||||
logger.error(error_msg)
|
|
||||||
return _create_error_response(error_msg, error_type)
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
error_type, error_msg = _classify_request_error(e)
|
|
||||||
logger.error(error_msg)
|
|
||||||
return _create_error_response(error_msg, error_type)
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
return _create_error_response(error_msg, "unexpected_error")
|
|
||||||
|
|
||||||
|
|
||||||
async def get_blocks_external() -> list[dict[str, Any]] | None:
|
async def get_blocks_external() -> list[dict[str, Any]] | None:
|
||||||
"""Get available blocks from the external service.
|
"""Get available blocks from the external service.
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from .base import BaseTool
|
|||||||
from .models import (
|
from .models import (
|
||||||
AgentPreviewResponse,
|
AgentPreviewResponse,
|
||||||
AgentSavedResponse,
|
AgentSavedResponse,
|
||||||
AsyncProcessingResponse,
|
|
||||||
ClarificationNeededResponse,
|
ClarificationNeededResponse,
|
||||||
ClarifyingQuestion,
|
ClarifyingQuestion,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
@@ -99,10 +98,6 @@ class CreateAgentTool(BaseTool):
|
|||||||
save = kwargs.get("save", True)
|
save = kwargs.get("save", True)
|
||||||
session_id = session.session_id if session else None
|
session_id = session.session_id if session else None
|
||||||
|
|
||||||
# Extract async processing params (passed by long-running tool handler)
|
|
||||||
operation_id = kwargs.get("_operation_id")
|
|
||||||
task_id = kwargs.get("_task_id")
|
|
||||||
|
|
||||||
if not description:
|
if not description:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please provide a description of what the agent should do.",
|
message="Please provide a description of what the agent should do.",
|
||||||
@@ -224,12 +219,7 @@ class CreateAgentTool(BaseTool):
|
|||||||
logger.warning(f"Failed to enrich library agents from steps: {e}")
|
logger.warning(f"Failed to enrich library agents from steps: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
agent_json = await generate_agent(
|
agent_json = await generate_agent(decomposition_result, library_agents)
|
||||||
decomposition_result,
|
|
||||||
library_agents,
|
|
||||||
operation_id=operation_id,
|
|
||||||
task_id=task_id,
|
|
||||||
)
|
|
||||||
except AgentGeneratorNotConfiguredError:
|
except AgentGeneratorNotConfiguredError:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=(
|
message=(
|
||||||
@@ -273,19 +263,6 @@ class CreateAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if Agent Generator accepted for async processing
|
|
||||||
if agent_json.get("status") == "accepted":
|
|
||||||
logger.info(
|
|
||||||
f"Agent generation delegated to async processing "
|
|
||||||
f"(operation_id={operation_id}, task_id={task_id})"
|
|
||||||
)
|
|
||||||
return AsyncProcessingResponse(
|
|
||||||
message="Agent generation started. You'll be notified when it's complete.",
|
|
||||||
operation_id=operation_id,
|
|
||||||
task_id=task_id,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
agent_name = agent_json.get("name", "Generated Agent")
|
agent_name = agent_json.get("name", "Generated Agent")
|
||||||
agent_description = agent_json.get("description", "")
|
agent_description = agent_json.get("description", "")
|
||||||
node_count = len(agent_json.get("nodes", []))
|
node_count = len(agent_json.get("nodes", []))
|
||||||
|
|||||||
@@ -1,337 +0,0 @@
|
|||||||
"""CustomizeAgentTool - Customizes marketplace/template agents using natural language."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
|
||||||
from backend.api.features.store import db as store_db
|
|
||||||
from backend.api.features.store.exceptions import AgentNotFoundError
|
|
||||||
|
|
||||||
from .agent_generator import (
|
|
||||||
AgentGeneratorNotConfiguredError,
|
|
||||||
customize_template,
|
|
||||||
get_user_message_for_error,
|
|
||||||
graph_to_json,
|
|
||||||
save_agent_to_library,
|
|
||||||
)
|
|
||||||
from .base import BaseTool
|
|
||||||
from .models import (
|
|
||||||
AgentPreviewResponse,
|
|
||||||
AgentSavedResponse,
|
|
||||||
ClarificationNeededResponse,
|
|
||||||
ClarifyingQuestion,
|
|
||||||
ErrorResponse,
|
|
||||||
ToolResponseBase,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class CustomizeAgentTool(BaseTool):
|
|
||||||
"""Tool for customizing marketplace/template agents using natural language."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "customize_agent"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return (
|
|
||||||
"Customize a marketplace or template agent using natural language. "
|
|
||||||
"Takes an existing agent from the marketplace and modifies it based on "
|
|
||||||
"the user's requirements before adding to their library."
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_auth(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_long_running(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"agent_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"The marketplace agent ID in format 'creator/slug' "
|
|
||||||
"(e.g., 'autogpt/newsletter-writer'). "
|
|
||||||
"Get this from find_agent results."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"modifications": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"Natural language description of how to customize the agent. "
|
|
||||||
"Be specific about what changes you want to make."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"context": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"Additional context or answers to previous clarifying questions."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"save": {
|
|
||||||
"type": "boolean",
|
|
||||||
"description": (
|
|
||||||
"Whether to save the customized agent to the user's library. "
|
|
||||||
"Default is true. Set to false for preview only."
|
|
||||||
),
|
|
||||||
"default": True,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["agent_id", "modifications"],
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _execute(
|
|
||||||
self,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
**kwargs,
|
|
||||||
) -> ToolResponseBase:
|
|
||||||
"""Execute the customize_agent tool.
|
|
||||||
|
|
||||||
Flow:
|
|
||||||
1. Parse the agent ID to get creator/slug
|
|
||||||
2. Fetch the template agent from the marketplace
|
|
||||||
3. Call customize_template with the modification request
|
|
||||||
4. Preview or save based on the save parameter
|
|
||||||
"""
|
|
||||||
agent_id = kwargs.get("agent_id", "").strip()
|
|
||||||
modifications = kwargs.get("modifications", "").strip()
|
|
||||||
context = kwargs.get("context", "")
|
|
||||||
save = kwargs.get("save", True)
|
|
||||||
session_id = session.session_id if session else None
|
|
||||||
|
|
||||||
if not agent_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
|
|
||||||
error="missing_agent_id",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not modifications:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Please describe how you want to customize this agent.",
|
|
||||||
error="missing_modifications",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Parse agent_id in format "creator/slug"
|
|
||||||
parts = [p.strip() for p in agent_id.split("/")]
|
|
||||||
if len(parts) != 2 or not parts[0] or not parts[1]:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
f"Invalid agent ID format: '{agent_id}'. "
|
|
||||||
"Expected format is 'creator/agent-name' "
|
|
||||||
"(e.g., 'autogpt/newsletter-writer')."
|
|
||||||
),
|
|
||||||
error="invalid_agent_id_format",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
creator_username, agent_slug = parts
|
|
||||||
|
|
||||||
# Fetch the marketplace agent details
|
|
||||||
try:
|
|
||||||
agent_details = await store_db.get_store_agent_details(
|
|
||||||
username=creator_username, agent_name=agent_slug
|
|
||||||
)
|
|
||||||
except AgentNotFoundError:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
f"Could not find marketplace agent '{agent_id}'. "
|
|
||||||
"Please check the agent ID and try again."
|
|
||||||
),
|
|
||||||
error="agent_not_found",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Failed to fetch the marketplace agent. Please try again.",
|
|
||||||
error="fetch_error",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not agent_details.store_listing_version_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
f"The agent '{agent_id}' does not have an available version. "
|
|
||||||
"Please try a different agent."
|
|
||||||
),
|
|
||||||
error="no_version_available",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get the full agent graph
|
|
||||||
try:
|
|
||||||
graph = await store_db.get_agent(agent_details.store_listing_version_id)
|
|
||||||
template_agent = graph_to_json(graph)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Failed to fetch the agent configuration. Please try again.",
|
|
||||||
error="graph_fetch_error",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Call customize_template
|
|
||||||
try:
|
|
||||||
result = await customize_template(
|
|
||||||
template_agent=template_agent,
|
|
||||||
modification_request=modifications,
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
except AgentGeneratorNotConfiguredError:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
"Agent customization is not available. "
|
|
||||||
"The Agent Generator service is not configured."
|
|
||||||
),
|
|
||||||
error="service_not_configured",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error calling customize_template for {agent_id}: {e}")
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
"Failed to customize the agent due to a service error. "
|
|
||||||
"Please try again."
|
|
||||||
),
|
|
||||||
error="customization_service_error",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if result is None:
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
"Failed to customize the agent. "
|
|
||||||
"The agent generation service may be unavailable or timed out. "
|
|
||||||
"Please try again."
|
|
||||||
),
|
|
||||||
error="customization_failed",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle error response
|
|
||||||
if isinstance(result, dict) and result.get("type") == "error":
|
|
||||||
error_msg = result.get("error", "Unknown error")
|
|
||||||
error_type = result.get("error_type", "unknown")
|
|
||||||
user_message = get_user_message_for_error(
|
|
||||||
error_type,
|
|
||||||
operation="customize the agent",
|
|
||||||
llm_parse_message=(
|
|
||||||
"The AI had trouble customizing the agent. "
|
|
||||||
"Please try again or simplify your request."
|
|
||||||
),
|
|
||||||
validation_message=(
|
|
||||||
"The customized agent failed validation. "
|
|
||||||
"Please try rephrasing your request."
|
|
||||||
),
|
|
||||||
error_details=error_msg,
|
|
||||||
)
|
|
||||||
return ErrorResponse(
|
|
||||||
message=user_message,
|
|
||||||
error=f"customization_failed:{error_type}",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle clarifying questions
|
|
||||||
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
|
|
||||||
questions = result.get("questions") or []
|
|
||||||
if not isinstance(questions, list):
|
|
||||||
logger.error(
|
|
||||||
f"Unexpected clarifying questions format: {type(questions)}"
|
|
||||||
)
|
|
||||||
questions = []
|
|
||||||
return ClarificationNeededResponse(
|
|
||||||
message=(
|
|
||||||
"I need some more information to customize this agent. "
|
|
||||||
"Please answer the following questions:"
|
|
||||||
),
|
|
||||||
questions=[
|
|
||||||
ClarifyingQuestion(
|
|
||||||
question=q.get("question", ""),
|
|
||||||
keyword=q.get("keyword", ""),
|
|
||||||
example=q.get("example"),
|
|
||||||
)
|
|
||||||
for q in questions
|
|
||||||
if isinstance(q, dict)
|
|
||||||
],
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Result should be the customized agent JSON
|
|
||||||
if not isinstance(result, dict):
|
|
||||||
logger.error(f"Unexpected customize_template response type: {type(result)}")
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Failed to customize the agent due to an unexpected response.",
|
|
||||||
error="unexpected_response_type",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
customized_agent = result
|
|
||||||
|
|
||||||
agent_name = customized_agent.get(
|
|
||||||
"name", f"Customized {agent_details.agent_name}"
|
|
||||||
)
|
|
||||||
agent_description = customized_agent.get("description", "")
|
|
||||||
nodes = customized_agent.get("nodes")
|
|
||||||
links = customized_agent.get("links")
|
|
||||||
node_count = len(nodes) if isinstance(nodes, list) else 0
|
|
||||||
link_count = len(links) if isinstance(links, list) else 0
|
|
||||||
|
|
||||||
if not save:
|
|
||||||
return AgentPreviewResponse(
|
|
||||||
message=(
|
|
||||||
f"I've customized the agent '{agent_details.agent_name}'. "
|
|
||||||
f"The customized agent has {node_count} blocks. "
|
|
||||||
f"Review it and call customize_agent with save=true to save it."
|
|
||||||
),
|
|
||||||
agent_json=customized_agent,
|
|
||||||
agent_name=agent_name,
|
|
||||||
description=agent_description,
|
|
||||||
node_count=node_count,
|
|
||||||
link_count=link_count,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not user_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="You must be logged in to save agents.",
|
|
||||||
error="auth_required",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save to user's library
|
|
||||||
try:
|
|
||||||
created_graph, library_agent = await save_agent_to_library(
|
|
||||||
customized_agent, user_id, is_update=False
|
|
||||||
)
|
|
||||||
|
|
||||||
return AgentSavedResponse(
|
|
||||||
message=(
|
|
||||||
f"Customized agent '{created_graph.name}' "
|
|
||||||
f"(based on '{agent_details.agent_name}') "
|
|
||||||
f"has been saved to your library!"
|
|
||||||
),
|
|
||||||
agent_id=created_graph.id,
|
|
||||||
agent_name=created_graph.name,
|
|
||||||
library_agent_id=library_agent.id,
|
|
||||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
|
||||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error saving customized agent: {e}")
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Failed to save the customized agent. Please try again.",
|
|
||||||
error="save_failed",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
@@ -17,7 +17,6 @@ from .base import BaseTool
|
|||||||
from .models import (
|
from .models import (
|
||||||
AgentPreviewResponse,
|
AgentPreviewResponse,
|
||||||
AgentSavedResponse,
|
AgentSavedResponse,
|
||||||
AsyncProcessingResponse,
|
|
||||||
ClarificationNeededResponse,
|
ClarificationNeededResponse,
|
||||||
ClarifyingQuestion,
|
ClarifyingQuestion,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
@@ -105,10 +104,6 @@ class EditAgentTool(BaseTool):
|
|||||||
save = kwargs.get("save", True)
|
save = kwargs.get("save", True)
|
||||||
session_id = session.session_id if session else None
|
session_id = session.session_id if session else None
|
||||||
|
|
||||||
# Extract async processing params (passed by long-running tool handler)
|
|
||||||
operation_id = kwargs.get("_operation_id")
|
|
||||||
task_id = kwargs.get("_task_id")
|
|
||||||
|
|
||||||
if not agent_id:
|
if not agent_id:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please provide the agent ID to edit.",
|
message="Please provide the agent ID to edit.",
|
||||||
@@ -154,11 +149,7 @@ class EditAgentTool(BaseTool):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
result = await generate_agent_patch(
|
result = await generate_agent_patch(
|
||||||
update_request,
|
update_request, current_agent, library_agents
|
||||||
current_agent,
|
|
||||||
library_agents,
|
|
||||||
operation_id=operation_id,
|
|
||||||
task_id=task_id,
|
|
||||||
)
|
)
|
||||||
except AgentGeneratorNotConfiguredError:
|
except AgentGeneratorNotConfiguredError:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
@@ -178,20 +169,6 @@ class EditAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if Agent Generator accepted for async processing
|
|
||||||
if result.get("status") == "accepted":
|
|
||||||
logger.info(
|
|
||||||
f"Agent edit delegated to async processing "
|
|
||||||
f"(operation_id={operation_id}, task_id={task_id})"
|
|
||||||
)
|
|
||||||
return AsyncProcessingResponse(
|
|
||||||
message="Agent edit started. You'll be notified when it's complete.",
|
|
||||||
operation_id=operation_id,
|
|
||||||
task_id=task_id,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if the result is an error from the external service
|
|
||||||
if isinstance(result, dict) and result.get("type") == "error":
|
if isinstance(result, dict) and result.get("type") == "error":
|
||||||
error_msg = result.get("error", "Unknown error")
|
error_msg = result.get("error", "Unknown error")
|
||||||
error_type = result.get("error_type", "unknown")
|
error_type = result.get("error_type", "unknown")
|
||||||
|
|||||||
@@ -38,8 +38,6 @@ class ResponseType(str, Enum):
|
|||||||
OPERATION_STARTED = "operation_started"
|
OPERATION_STARTED = "operation_started"
|
||||||
OPERATION_PENDING = "operation_pending"
|
OPERATION_PENDING = "operation_pending"
|
||||||
OPERATION_IN_PROGRESS = "operation_in_progress"
|
OPERATION_IN_PROGRESS = "operation_in_progress"
|
||||||
# Input validation
|
|
||||||
INPUT_VALIDATION_ERROR = "input_validation_error"
|
|
||||||
|
|
||||||
|
|
||||||
# Base response model
|
# Base response model
|
||||||
@@ -70,10 +68,6 @@ class AgentInfo(BaseModel):
|
|||||||
has_external_trigger: bool | None = None
|
has_external_trigger: bool | None = None
|
||||||
new_output: bool | None = None
|
new_output: bool | None = None
|
||||||
graph_id: str | None = None
|
graph_id: str | None = None
|
||||||
inputs: dict[str, Any] | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="Input schema for the agent, including field names, types, and defaults",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AgentsFoundResponse(ToolResponseBase):
|
class AgentsFoundResponse(ToolResponseBase):
|
||||||
@@ -200,20 +194,6 @@ class ErrorResponse(ToolResponseBase):
|
|||||||
details: dict[str, Any] | None = None
|
details: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
class InputValidationErrorResponse(ToolResponseBase):
|
|
||||||
"""Response when run_agent receives unknown input fields."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.INPUT_VALIDATION_ERROR
|
|
||||||
unrecognized_fields: list[str] = Field(
|
|
||||||
description="List of input field names that were not recognized"
|
|
||||||
)
|
|
||||||
inputs: dict[str, Any] = Field(
|
|
||||||
description="The agent's valid input schema for reference"
|
|
||||||
)
|
|
||||||
graph_id: str | None = None
|
|
||||||
graph_version: int | None = None
|
|
||||||
|
|
||||||
|
|
||||||
# Agent output models
|
# Agent output models
|
||||||
class ExecutionOutputInfo(BaseModel):
|
class ExecutionOutputInfo(BaseModel):
|
||||||
"""Summary of a single execution's outputs."""
|
"""Summary of a single execution's outputs."""
|
||||||
@@ -372,15 +352,11 @@ class OperationStartedResponse(ToolResponseBase):
|
|||||||
|
|
||||||
This is returned immediately to the client while the operation continues
|
This is returned immediately to the client while the operation continues
|
||||||
to execute. The user can close the tab and check back later.
|
to execute. The user can close the tab and check back later.
|
||||||
|
|
||||||
The task_id can be used to reconnect to the SSE stream via
|
|
||||||
GET /chat/tasks/{task_id}/stream?last_idx=0
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||||
operation_id: str
|
operation_id: str
|
||||||
tool_name: str
|
tool_name: str
|
||||||
task_id: str | None = None # For SSE reconnection
|
|
||||||
|
|
||||||
|
|
||||||
class OperationPendingResponse(ToolResponseBase):
|
class OperationPendingResponse(ToolResponseBase):
|
||||||
@@ -404,20 +380,3 @@ class OperationInProgressResponse(ToolResponseBase):
|
|||||||
|
|
||||||
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
|
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
|
||||||
tool_call_id: str
|
tool_call_id: str
|
||||||
|
|
||||||
|
|
||||||
class AsyncProcessingResponse(ToolResponseBase):
|
|
||||||
"""Response when an operation has been delegated to async processing.
|
|
||||||
|
|
||||||
This is returned by tools when the external service accepts the request
|
|
||||||
for async processing (HTTP 202 Accepted). The Redis Streams completion
|
|
||||||
consumer will handle the result when the external service completes.
|
|
||||||
|
|
||||||
The status field is specifically "accepted" to allow the long-running tool
|
|
||||||
handler to detect this response and skip LLM continuation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
|
||||||
status: str = "accepted" # Must be "accepted" for detection
|
|
||||||
operation_id: str | None = None
|
|
||||||
task_id: str | None = None
|
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ from .models import (
|
|||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
ExecutionOptions,
|
ExecutionOptions,
|
||||||
ExecutionStartedResponse,
|
ExecutionStartedResponse,
|
||||||
InputValidationErrorResponse,
|
|
||||||
SetupInfo,
|
SetupInfo,
|
||||||
SetupRequirementsResponse,
|
SetupRequirementsResponse,
|
||||||
ToolResponseBase,
|
ToolResponseBase,
|
||||||
@@ -274,22 +273,6 @@ class RunAgentTool(BaseTool):
|
|||||||
input_properties = graph.input_schema.get("properties", {})
|
input_properties = graph.input_schema.get("properties", {})
|
||||||
required_fields = set(graph.input_schema.get("required", []))
|
required_fields = set(graph.input_schema.get("required", []))
|
||||||
provided_inputs = set(params.inputs.keys())
|
provided_inputs = set(params.inputs.keys())
|
||||||
valid_fields = set(input_properties.keys())
|
|
||||||
|
|
||||||
# Check for unknown input fields
|
|
||||||
unrecognized_fields = provided_inputs - valid_fields
|
|
||||||
if unrecognized_fields:
|
|
||||||
return InputValidationErrorResponse(
|
|
||||||
message=(
|
|
||||||
f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
|
|
||||||
f"Agent was not executed. Please use the correct field names from the schema."
|
|
||||||
),
|
|
||||||
session_id=session_id,
|
|
||||||
unrecognized_fields=sorted(unrecognized_fields),
|
|
||||||
inputs=graph.input_schema,
|
|
||||||
graph_id=graph.id,
|
|
||||||
graph_version=graph.version,
|
|
||||||
)
|
|
||||||
|
|
||||||
# If agent has inputs but none were provided AND use_defaults is not set,
|
# If agent has inputs but none were provided AND use_defaults is not set,
|
||||||
# always show what's available first so user can decide
|
# always show what's available first so user can decide
|
||||||
|
|||||||
@@ -402,42 +402,3 @@ async def test_run_agent_schedule_without_name(setup_test_data):
|
|||||||
# Should return error about missing schedule_name
|
# Should return error about missing schedule_name
|
||||||
assert result_data.get("type") == "error"
|
assert result_data.get("type") == "error"
|
||||||
assert "schedule_name" in result_data["message"].lower()
|
assert "schedule_name" in result_data["message"].lower()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_run_agent_rejects_unknown_input_fields(setup_test_data):
|
|
||||||
"""Test that run_agent returns input_validation_error for unknown input fields."""
|
|
||||||
user = setup_test_data["user"]
|
|
||||||
store_submission = setup_test_data["store_submission"]
|
|
||||||
|
|
||||||
tool = RunAgentTool()
|
|
||||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
|
||||||
session = make_session(user_id=user.id)
|
|
||||||
|
|
||||||
# Execute with unknown input field names
|
|
||||||
response = await tool.execute(
|
|
||||||
user_id=user.id,
|
|
||||||
session_id=str(uuid.uuid4()),
|
|
||||||
tool_call_id=str(uuid.uuid4()),
|
|
||||||
username_agent_slug=agent_marketplace_id,
|
|
||||||
inputs={
|
|
||||||
"unknown_field": "some value",
|
|
||||||
"another_unknown": "another value",
|
|
||||||
},
|
|
||||||
session=session,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response is not None
|
|
||||||
assert hasattr(response, "output")
|
|
||||||
assert isinstance(response.output, str)
|
|
||||||
result_data = orjson.loads(response.output)
|
|
||||||
|
|
||||||
# Should return input_validation_error type with unrecognized fields
|
|
||||||
assert result_data.get("type") == "input_validation_error"
|
|
||||||
assert "unrecognized_fields" in result_data
|
|
||||||
assert set(result_data["unrecognized_fields"]) == {
|
|
||||||
"another_unknown",
|
|
||||||
"unknown_field",
|
|
||||||
}
|
|
||||||
assert "inputs" in result_data # Contains the valid schema
|
|
||||||
assert "Agent was not executed" in result_data["message"]
|
|
||||||
|
|||||||
@@ -5,8 +5,6 @@ import uuid
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic_core import PydanticUndefined
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.data.block import get_block
|
from backend.data.block import get_block
|
||||||
from backend.data.execution import ExecutionContext
|
from backend.data.execution import ExecutionContext
|
||||||
@@ -77,22 +75,15 @@ class RunBlockTool(BaseTool):
|
|||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
block: Any,
|
block: Any,
|
||||||
input_data: dict[str, Any] | None = None,
|
|
||||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||||
"""
|
"""
|
||||||
Check if user has required credentials for a block.
|
Check if user has required credentials for a block.
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User ID
|
|
||||||
block: Block to check credentials for
|
|
||||||
input_data: Input data for the block (used to determine provider via discriminator)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[matched_credentials, missing_credentials]
|
tuple[matched_credentials, missing_credentials]
|
||||||
"""
|
"""
|
||||||
matched_credentials: dict[str, CredentialsMetaInput] = {}
|
matched_credentials: dict[str, CredentialsMetaInput] = {}
|
||||||
missing_credentials: list[CredentialsMetaInput] = []
|
missing_credentials: list[CredentialsMetaInput] = []
|
||||||
input_data = input_data or {}
|
|
||||||
|
|
||||||
# Get credential field info from block's input schema
|
# Get credential field info from block's input schema
|
||||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||||
@@ -105,33 +96,14 @@ class RunBlockTool(BaseTool):
|
|||||||
available_creds = await creds_manager.store.get_all_creds(user_id)
|
available_creds = await creds_manager.store.get_all_creds(user_id)
|
||||||
|
|
||||||
for field_name, field_info in credentials_fields_info.items():
|
for field_name, field_info in credentials_fields_info.items():
|
||||||
effective_field_info = field_info
|
# field_info.provider is a frozenset of acceptable providers
|
||||||
if field_info.discriminator and field_info.discriminator_mapping:
|
# field_info.supported_types is a frozenset of acceptable types
|
||||||
# Get discriminator from input, falling back to schema default
|
|
||||||
discriminator_value = input_data.get(field_info.discriminator)
|
|
||||||
if discriminator_value is None:
|
|
||||||
field = block.input_schema.model_fields.get(
|
|
||||||
field_info.discriminator
|
|
||||||
)
|
|
||||||
if field and field.default is not PydanticUndefined:
|
|
||||||
discriminator_value = field.default
|
|
||||||
|
|
||||||
if (
|
|
||||||
discriminator_value
|
|
||||||
and discriminator_value in field_info.discriminator_mapping
|
|
||||||
):
|
|
||||||
effective_field_info = field_info.discriminate(discriminator_value)
|
|
||||||
logger.debug(
|
|
||||||
f"Discriminated provider for {field_name}: "
|
|
||||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
|
||||||
)
|
|
||||||
|
|
||||||
matching_cred = next(
|
matching_cred = next(
|
||||||
(
|
(
|
||||||
cred
|
cred
|
||||||
for cred in available_creds
|
for cred in available_creds
|
||||||
if cred.provider in effective_field_info.provider
|
if cred.provider in field_info.provider
|
||||||
and cred.type in effective_field_info.supported_types
|
and cred.type in field_info.supported_types
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
@@ -145,8 +117,8 @@ class RunBlockTool(BaseTool):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Create a placeholder for the missing credential
|
# Create a placeholder for the missing credential
|
||||||
provider = next(iter(effective_field_info.provider), "unknown")
|
provider = next(iter(field_info.provider), "unknown")
|
||||||
cred_type = next(iter(effective_field_info.supported_types), "api_key")
|
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||||
missing_credentials.append(
|
missing_credentials.append(
|
||||||
CredentialsMetaInput(
|
CredentialsMetaInput(
|
||||||
id=field_name,
|
id=field_name,
|
||||||
@@ -214,9 +186,10 @@ class RunBlockTool(BaseTool):
|
|||||||
|
|
||||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||||
|
|
||||||
|
# Check credentials
|
||||||
creds_manager = IntegrationCredentialsManager()
|
creds_manager = IntegrationCredentialsManager()
|
||||||
matched_credentials, missing_credentials = await self._check_block_credentials(
|
matched_credentials, missing_credentials = await self._check_block_credentials(
|
||||||
user_id, block, input_data
|
user_id, block
|
||||||
)
|
)
|
||||||
|
|
||||||
if missing_credentials:
|
if missing_credentials:
|
||||||
|
|||||||
@@ -454,9 +454,6 @@ async def test_unified_hybrid_search_pagination(
|
|||||||
cleanup_embeddings: list,
|
cleanup_embeddings: list,
|
||||||
):
|
):
|
||||||
"""Test unified search pagination works correctly."""
|
"""Test unified search pagination works correctly."""
|
||||||
# Use a unique search term to avoid matching other test data
|
|
||||||
unique_term = f"xyzpagtest{uuid.uuid4().hex[:8]}"
|
|
||||||
|
|
||||||
# Create multiple items
|
# Create multiple items
|
||||||
content_ids = []
|
content_ids = []
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
@@ -468,14 +465,14 @@ async def test_unified_hybrid_search_pagination(
|
|||||||
content_type=ContentType.BLOCK,
|
content_type=ContentType.BLOCK,
|
||||||
content_id=content_id,
|
content_id=content_id,
|
||||||
embedding=mock_embedding,
|
embedding=mock_embedding,
|
||||||
searchable_text=f"{unique_term} item number {i}",
|
searchable_text=f"pagination test item number {i}",
|
||||||
metadata={"index": i},
|
metadata={"index": i},
|
||||||
user_id=None,
|
user_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get first page
|
# Get first page
|
||||||
page1_results, total1 = await unified_hybrid_search(
|
page1_results, total1 = await unified_hybrid_search(
|
||||||
query=unique_term,
|
query="pagination test",
|
||||||
content_types=[ContentType.BLOCK],
|
content_types=[ContentType.BLOCK],
|
||||||
page=1,
|
page=1,
|
||||||
page_size=2,
|
page_size=2,
|
||||||
@@ -483,7 +480,7 @@ async def test_unified_hybrid_search_pagination(
|
|||||||
|
|
||||||
# Get second page
|
# Get second page
|
||||||
page2_results, total2 = await unified_hybrid_search(
|
page2_results, total2 = await unified_hybrid_search(
|
||||||
query=unique_term,
|
query="pagination test",
|
||||||
content_types=[ContentType.BLOCK],
|
content_types=[ContentType.BLOCK],
|
||||||
page=2,
|
page=2,
|
||||||
page_size=2,
|
page_size=2,
|
||||||
|
|||||||
@@ -40,10 +40,6 @@ import backend.data.user
|
|||||||
import backend.integrations.webhooks.utils
|
import backend.integrations.webhooks.utils
|
||||||
import backend.util.service
|
import backend.util.service
|
||||||
import backend.util.settings
|
import backend.util.settings
|
||||||
from backend.api.features.chat.completion_consumer import (
|
|
||||||
start_completion_consumer,
|
|
||||||
stop_completion_consumer,
|
|
||||||
)
|
|
||||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||||
from backend.data.model import Credentials
|
from backend.data.model import Credentials
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
@@ -122,21 +118,9 @@ async def lifespan_context(app: fastapi.FastAPI):
|
|||||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||||
|
|
||||||
# Start chat completion consumer for Redis Streams notifications
|
|
||||||
try:
|
|
||||||
await start_completion_consumer()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Could not start chat completion consumer: {e}")
|
|
||||||
|
|
||||||
with launch_darkly_context():
|
with launch_darkly_context():
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Stop chat completion consumer
|
|
||||||
try:
|
|
||||||
await stop_completion_consumer()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error stopping chat completion consumer: {e}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await shutdown_cloud_storage_handler()
|
await shutdown_cloud_storage_handler()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -66,24 +66,18 @@ async def event_broadcaster(manager: ConnectionManager):
|
|||||||
execution_bus = AsyncRedisExecutionEventBus()
|
execution_bus = AsyncRedisExecutionEventBus()
|
||||||
notification_bus = AsyncRedisNotificationEventBus()
|
notification_bus = AsyncRedisNotificationEventBus()
|
||||||
|
|
||||||
try:
|
async def execution_worker():
|
||||||
|
async for event in execution_bus.listen("*"):
|
||||||
|
await manager.send_execution_update(event)
|
||||||
|
|
||||||
async def execution_worker():
|
async def notification_worker():
|
||||||
async for event in execution_bus.listen("*"):
|
async for notification in notification_bus.listen("*"):
|
||||||
await manager.send_execution_update(event)
|
await manager.send_notification(
|
||||||
|
user_id=notification.user_id,
|
||||||
|
payload=notification.payload,
|
||||||
|
)
|
||||||
|
|
||||||
async def notification_worker():
|
await asyncio.gather(execution_worker(), notification_worker())
|
||||||
async for notification in notification_bus.listen("*"):
|
|
||||||
await manager.send_notification(
|
|
||||||
user_id=notification.user_id,
|
|
||||||
payload=notification.payload,
|
|
||||||
)
|
|
||||||
|
|
||||||
await asyncio.gather(execution_worker(), notification_worker())
|
|
||||||
finally:
|
|
||||||
# Ensure PubSub connections are closed on any exit to prevent leaks
|
|
||||||
await execution_bus.close()
|
|
||||||
await notification_bus.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def authenticate_websocket(websocket: WebSocket) -> str:
|
async def authenticate_websocket(websocket: WebSocket) -> str:
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from backend.data.model import (
|
|||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util import json
|
from backend.util import json
|
||||||
from backend.util.logging import TruncatedLogger
|
from backend.util.logging import TruncatedLogger
|
||||||
from backend.util.prompt import compress_context, estimate_token_count
|
from backend.util.prompt import compress_prompt, estimate_token_count
|
||||||
from backend.util.text import TextFormatter
|
from backend.util.text import TextFormatter
|
||||||
|
|
||||||
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
||||||
@@ -634,18 +634,11 @@ async def llm_call(
|
|||||||
context_window = llm_model.context_window
|
context_window = llm_model.context_window
|
||||||
|
|
||||||
if compress_prompt_to_fit:
|
if compress_prompt_to_fit:
|
||||||
result = await compress_context(
|
prompt = compress_prompt(
|
||||||
messages=prompt,
|
messages=prompt,
|
||||||
target_tokens=llm_model.context_window // 2,
|
target_tokens=llm_model.context_window // 2,
|
||||||
client=None, # Truncation-only, no LLM summarization
|
lossy_ok=True,
|
||||||
reserve=0, # Caller handles response token budget separately
|
|
||||||
)
|
)
|
||||||
if result.error:
|
|
||||||
logger.warning(
|
|
||||||
f"Prompt compression did not meet target: {result.error}. "
|
|
||||||
f"Proceeding with {result.token_count} tokens."
|
|
||||||
)
|
|
||||||
prompt = result.messages
|
|
||||||
|
|
||||||
# Calculate available tokens based on context window and input length
|
# Calculate available tokens based on context window and input length
|
||||||
estimated_input_tokens = estimate_token_count(prompt)
|
estimated_input_tokens = estimate_token_count(prompt)
|
||||||
|
|||||||
@@ -873,13 +873,14 @@ def is_block_auth_configured(
|
|||||||
|
|
||||||
|
|
||||||
async def initialize_blocks() -> None:
|
async def initialize_blocks() -> None:
|
||||||
|
# First, sync all provider costs to blocks
|
||||||
|
# Imported here to avoid circular import
|
||||||
from backend.sdk.cost_integration import sync_all_provider_costs
|
from backend.sdk.cost_integration import sync_all_provider_costs
|
||||||
from backend.util.retry import func_retry
|
|
||||||
|
|
||||||
sync_all_provider_costs()
|
sync_all_provider_costs()
|
||||||
|
|
||||||
@func_retry
|
for cls in get_blocks().values():
|
||||||
async def sync_block_to_db(block: Block) -> None:
|
block = cls()
|
||||||
existing_block = await AgentBlock.prisma().find_first(
|
existing_block = await AgentBlock.prisma().find_first(
|
||||||
where={"OR": [{"id": block.id}, {"name": block.name}]}
|
where={"OR": [{"id": block.id}, {"name": block.name}]}
|
||||||
)
|
)
|
||||||
@@ -892,7 +893,7 @@ async def initialize_blocks() -> None:
|
|||||||
outputSchema=json.dumps(block.output_schema.jsonschema()),
|
outputSchema=json.dumps(block.output_schema.jsonschema()),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return
|
continue
|
||||||
|
|
||||||
input_schema = json.dumps(block.input_schema.jsonschema())
|
input_schema = json.dumps(block.input_schema.jsonschema())
|
||||||
output_schema = json.dumps(block.output_schema.jsonschema())
|
output_schema = json.dumps(block.output_schema.jsonschema())
|
||||||
@@ -912,25 +913,6 @@ async def initialize_blocks() -> None:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
failed_blocks: list[str] = []
|
|
||||||
for cls in get_blocks().values():
|
|
||||||
block = cls()
|
|
||||||
try:
|
|
||||||
await sync_block_to_db(block)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(
|
|
||||||
f"Failed to sync block {block.name} to database: {e}. "
|
|
||||||
"Block is still available in memory.",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
failed_blocks.append(block.name)
|
|
||||||
|
|
||||||
if failed_blocks:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to sync {len(failed_blocks)} block(s) to database: "
|
|
||||||
f"{', '.join(failed_blocks)}. These blocks are still available in memory."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
|
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
|
||||||
def get_block(block_id: str) -> AnyBlockSchema | None:
|
def get_block(block_id: str) -> AnyBlockSchema | None:
|
||||||
|
|||||||
@@ -133,23 +133,10 @@ class RedisEventBus(BaseRedisEventBus[M], ABC):
|
|||||||
|
|
||||||
|
|
||||||
class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
|
class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
|
||||||
def __init__(self):
|
|
||||||
self._pubsub: AsyncPubSub | None = None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
async def connection(self) -> redis.AsyncRedis:
|
async def connection(self) -> redis.AsyncRedis:
|
||||||
return await redis.get_redis_async()
|
return await redis.get_redis_async()
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""Close the PubSub connection if it exists."""
|
|
||||||
if self._pubsub is not None:
|
|
||||||
try:
|
|
||||||
await self._pubsub.close()
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to close PubSub connection", exc_info=True)
|
|
||||||
finally:
|
|
||||||
self._pubsub = None
|
|
||||||
|
|
||||||
async def publish_event(self, event: M, channel_key: str):
|
async def publish_event(self, event: M, channel_key: str):
|
||||||
"""
|
"""
|
||||||
Publish an event to Redis. Gracefully handles connection failures
|
Publish an event to Redis. Gracefully handles connection failures
|
||||||
@@ -170,7 +157,6 @@ class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
|
|||||||
await self.connection, channel_key
|
await self.connection, channel_key
|
||||||
)
|
)
|
||||||
assert isinstance(pubsub, AsyncPubSub)
|
assert isinstance(pubsub, AsyncPubSub)
|
||||||
self._pubsub = pubsub
|
|
||||||
|
|
||||||
if "*" in channel_key:
|
if "*" in channel_key:
|
||||||
await pubsub.psubscribe(full_channel_name)
|
await pubsub.psubscribe(full_channel_name)
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ from backend.data.analytics import (
|
|||||||
get_accuracy_trends_and_alerts,
|
get_accuracy_trends_and_alerts,
|
||||||
get_marketplace_graphs_for_monitoring,
|
get_marketplace_graphs_for_monitoring,
|
||||||
)
|
)
|
||||||
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
|
|
||||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||||
from backend.data.execution import (
|
from backend.data.execution import (
|
||||||
create_graph_execution,
|
create_graph_execution,
|
||||||
@@ -220,9 +219,6 @@ class DatabaseManager(AppService):
|
|||||||
# Onboarding
|
# Onboarding
|
||||||
increment_onboarding_runs = _(increment_onboarding_runs)
|
increment_onboarding_runs = _(increment_onboarding_runs)
|
||||||
|
|
||||||
# OAuth
|
|
||||||
cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens)
|
|
||||||
|
|
||||||
# Store
|
# Store
|
||||||
get_store_agents = _(get_store_agents)
|
get_store_agents = _(get_store_agents)
|
||||||
get_store_agent_details = _(get_store_agent_details)
|
get_store_agent_details = _(get_store_agent_details)
|
||||||
@@ -353,9 +349,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
|||||||
# Onboarding
|
# Onboarding
|
||||||
increment_onboarding_runs = d.increment_onboarding_runs
|
increment_onboarding_runs = d.increment_onboarding_runs
|
||||||
|
|
||||||
# OAuth
|
|
||||||
cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens
|
|
||||||
|
|
||||||
# Store
|
# Store
|
||||||
get_store_agents = d.get_store_agents
|
get_store_agents = d.get_store_agents
|
||||||
get_store_agent_details = d.get_store_agent_details
|
get_store_agent_details = d.get_store_agent_details
|
||||||
|
|||||||
@@ -24,9 +24,11 @@ from dotenv import load_dotenv
|
|||||||
from pydantic import BaseModel, Field, ValidationError
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
from sqlalchemy import MetaData, create_engine
|
from sqlalchemy import MetaData, create_engine
|
||||||
|
|
||||||
|
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
|
||||||
from backend.data.block import BlockInput
|
from backend.data.block import BlockInput
|
||||||
from backend.data.execution import GraphExecutionWithNodes
|
from backend.data.execution import GraphExecutionWithNodes
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
|
from backend.data.onboarding import increment_onboarding_runs
|
||||||
from backend.executor import utils as execution_utils
|
from backend.executor import utils as execution_utils
|
||||||
from backend.monitoring import (
|
from backend.monitoring import (
|
||||||
NotificationJobArgs,
|
NotificationJobArgs,
|
||||||
@@ -36,11 +38,7 @@ from backend.monitoring import (
|
|||||||
report_execution_accuracy_alerts,
|
report_execution_accuracy_alerts,
|
||||||
report_late_executions,
|
report_late_executions,
|
||||||
)
|
)
|
||||||
from backend.util.clients import (
|
from backend.util.clients import get_database_manager_client, get_scheduler_client
|
||||||
get_database_manager_async_client,
|
|
||||||
get_database_manager_client,
|
|
||||||
get_scheduler_client,
|
|
||||||
)
|
|
||||||
from backend.util.cloud_storage import cleanup_expired_files_async
|
from backend.util.cloud_storage import cleanup_expired_files_async
|
||||||
from backend.util.exceptions import (
|
from backend.util.exceptions import (
|
||||||
GraphNotFoundError,
|
GraphNotFoundError,
|
||||||
@@ -150,7 +148,6 @@ def execute_graph(**kwargs):
|
|||||||
async def _execute_graph(**kwargs):
|
async def _execute_graph(**kwargs):
|
||||||
args = GraphExecutionJobArgs(**kwargs)
|
args = GraphExecutionJobArgs(**kwargs)
|
||||||
start_time = asyncio.get_event_loop().time()
|
start_time = asyncio.get_event_loop().time()
|
||||||
db = get_database_manager_async_client()
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Executing recurring job for graph #{args.graph_id}")
|
logger.info(f"Executing recurring job for graph #{args.graph_id}")
|
||||||
graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution(
|
graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution(
|
||||||
@@ -160,7 +157,7 @@ async def _execute_graph(**kwargs):
|
|||||||
inputs=args.input_data,
|
inputs=args.input_data,
|
||||||
graph_credentials_inputs=args.input_credentials,
|
graph_credentials_inputs=args.input_credentials,
|
||||||
)
|
)
|
||||||
await db.increment_onboarding_runs(args.user_id)
|
await increment_onboarding_runs(args.user_id)
|
||||||
elapsed = asyncio.get_event_loop().time() - start_time
|
elapsed = asyncio.get_event_loop().time() - start_time
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
|
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
|
||||||
@@ -196,9 +193,11 @@ async def _handle_graph_validation_error(args: "GraphExecutionJobArgs") -> None:
|
|||||||
user_id=args.user_id,
|
user_id=args.user_id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.warning(
|
||||||
f"Unable to unschedule graph: {args.graph_id} as this is an old job with no associated schedule_id please remove manually"
|
f"Old scheduled job for graph {args.graph_id} (user {args.user_id}) "
|
||||||
|
f"has no schedule_id, attempting targeted cleanup"
|
||||||
)
|
)
|
||||||
|
await _cleanup_old_schedules_without_id(args.graph_id, args.user_id)
|
||||||
|
|
||||||
|
|
||||||
async def _handle_graph_not_available(
|
async def _handle_graph_not_available(
|
||||||
@@ -241,6 +240,35 @@ async def _cleanup_orphaned_schedules_for_graph(graph_id: str, user_id: str) ->
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _cleanup_old_schedules_without_id(graph_id: str, user_id: str) -> None:
|
||||||
|
"""Remove only schedules that have no schedule_id in their job args.
|
||||||
|
|
||||||
|
Unlike _cleanup_orphaned_schedules_for_graph (which removes ALL schedules
|
||||||
|
for a graph), this only targets legacy jobs created before schedule_id was
|
||||||
|
added to GraphExecutionJobArgs, preserving any valid newer schedules.
|
||||||
|
"""
|
||||||
|
scheduler_client = get_scheduler_client()
|
||||||
|
schedules = await scheduler_client.get_execution_schedules(
|
||||||
|
graph_id=graph_id, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
for schedule in schedules:
|
||||||
|
if schedule.schedule_id is not None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await scheduler_client.delete_schedule(
|
||||||
|
schedule_id=schedule.id, user_id=user_id
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Cleaned up old schedule {schedule.id} (no schedule_id) "
|
||||||
|
f"for graph {graph_id}"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
f"Failed to delete old schedule {schedule.id} for graph {graph_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def cleanup_expired_files():
|
def cleanup_expired_files():
|
||||||
"""Clean up expired files from cloud storage."""
|
"""Clean up expired files from cloud storage."""
|
||||||
# Wait for completion
|
# Wait for completion
|
||||||
@@ -249,13 +277,8 @@ def cleanup_expired_files():
|
|||||||
|
|
||||||
def cleanup_oauth_tokens():
|
def cleanup_oauth_tokens():
|
||||||
"""Clean up expired OAuth tokens from the database."""
|
"""Clean up expired OAuth tokens from the database."""
|
||||||
|
|
||||||
# Wait for completion
|
# Wait for completion
|
||||||
async def _cleanup():
|
run_async(cleanup_expired_oauth_tokens())
|
||||||
db = get_database_manager_async_client()
|
|
||||||
return await db.cleanup_expired_oauth_tokens()
|
|
||||||
|
|
||||||
run_async(_cleanup())
|
|
||||||
|
|
||||||
|
|
||||||
def execution_accuracy_alerts():
|
def execution_accuracy_alerts():
|
||||||
|
|||||||
@@ -1,19 +1,10 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass
|
from typing import Any
|
||||||
from typing import TYPE_CHECKING, Any
|
|
||||||
|
|
||||||
from tiktoken import encoding_for_model
|
from tiktoken import encoding_for_model
|
||||||
|
|
||||||
from backend.util import json
|
from backend.util import json
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from openai import AsyncOpenAI
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------#
|
# ---------------------------------------------------------------------------#
|
||||||
# CONSTANTS #
|
# CONSTANTS #
|
||||||
# ---------------------------------------------------------------------------#
|
# ---------------------------------------------------------------------------#
|
||||||
@@ -109,17 +100,9 @@ def _is_objective_message(msg: dict) -> bool:
|
|||||||
def _truncate_tool_message_content(msg: dict, enc, max_tokens: int) -> None:
|
def _truncate_tool_message_content(msg: dict, enc, max_tokens: int) -> None:
|
||||||
"""
|
"""
|
||||||
Carefully truncate tool message content while preserving tool structure.
|
Carefully truncate tool message content while preserving tool structure.
|
||||||
Handles both Anthropic-style (list content) and OpenAI-style (string content) tool messages.
|
Only truncates tool_result content, leaves tool_use intact.
|
||||||
"""
|
"""
|
||||||
content = msg.get("content")
|
content = msg.get("content")
|
||||||
|
|
||||||
# OpenAI-style tool message: role="tool" with string content
|
|
||||||
if msg.get("role") == "tool" and isinstance(content, str):
|
|
||||||
if _tok_len(content, enc) > max_tokens:
|
|
||||||
msg["content"] = _truncate_middle_tokens(content, enc, max_tokens)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Anthropic-style: list content with tool_result items
|
|
||||||
if not isinstance(content, list):
|
if not isinstance(content, list):
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -157,6 +140,141 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
|
|||||||
# ---------------------------------------------------------------------------#
|
# ---------------------------------------------------------------------------#
|
||||||
|
|
||||||
|
|
||||||
|
def compress_prompt(
|
||||||
|
messages: list[dict],
|
||||||
|
target_tokens: int,
|
||||||
|
*,
|
||||||
|
model: str = "gpt-4o",
|
||||||
|
reserve: int = 2_048,
|
||||||
|
start_cap: int = 8_192,
|
||||||
|
floor_cap: int = 128,
|
||||||
|
lossy_ok: bool = True,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Shrink *messages* so that::
|
||||||
|
|
||||||
|
token_count(prompt) + reserve ≤ target_tokens
|
||||||
|
|
||||||
|
Strategy
|
||||||
|
--------
|
||||||
|
1. **Token-aware truncation** – progressively halve a per-message cap
|
||||||
|
(`start_cap`, `start_cap/2`, … `floor_cap`) and apply it to the
|
||||||
|
*content* of every message except the first and last. Tool shells
|
||||||
|
are included: we keep the envelope but shorten huge payloads.
|
||||||
|
2. **Middle-out deletion** – if still over the limit, delete whole
|
||||||
|
messages working outward from the centre, **skipping** any message
|
||||||
|
that contains ``tool_calls`` or has ``role == "tool"``.
|
||||||
|
3. **Last-chance trim** – if still too big, truncate the *first* and
|
||||||
|
*last* message bodies down to `floor_cap` tokens.
|
||||||
|
4. If the prompt is *still* too large:
|
||||||
|
• raise ``ValueError`` when ``lossy_ok == False`` (default)
|
||||||
|
• return the partially-trimmed prompt when ``lossy_ok == True``
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
messages Complete chat history (will be deep-copied).
|
||||||
|
model Model name; passed to tiktoken to pick the right
|
||||||
|
tokenizer (gpt-4o → 'o200k_base', others fallback).
|
||||||
|
target_tokens Hard ceiling for prompt size **excluding** the model's
|
||||||
|
forthcoming answer.
|
||||||
|
reserve How many tokens you want to leave available for that
|
||||||
|
answer (`max_tokens` in your subsequent completion call).
|
||||||
|
start_cap Initial per-message truncation ceiling (tokens).
|
||||||
|
floor_cap Lowest cap we'll accept before moving to deletions.
|
||||||
|
lossy_ok If *True* return best-effort prompt instead of raising
|
||||||
|
after all trim passes have been exhausted.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[dict] – A *new* messages list that abides by the rules above.
|
||||||
|
"""
|
||||||
|
enc = encoding_for_model(model) # best-match tokenizer
|
||||||
|
msgs = deepcopy(messages) # never mutate caller
|
||||||
|
|
||||||
|
def total_tokens() -> int:
|
||||||
|
"""Current size of *msgs* in tokens."""
|
||||||
|
return sum(_msg_tokens(m, enc) for m in msgs)
|
||||||
|
|
||||||
|
original_token_count = total_tokens()
|
||||||
|
|
||||||
|
if original_token_count + reserve <= target_tokens:
|
||||||
|
return msgs
|
||||||
|
|
||||||
|
# ---- STEP 0 : normalise content --------------------------------------
|
||||||
|
# Convert non-string payloads to strings so token counting is coherent.
|
||||||
|
for i, m in enumerate(msgs):
|
||||||
|
if not isinstance(m.get("content"), str) and m.get("content") is not None:
|
||||||
|
if _is_tool_message(m):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Keep first and last messages intact (unless they're tool messages)
|
||||||
|
if i == 0 or i == len(msgs) - 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Reasonable 20k-char ceiling prevents pathological blobs
|
||||||
|
content_str = json.dumps(m["content"], separators=(",", ":"))
|
||||||
|
if len(content_str) > 20_000:
|
||||||
|
content_str = _truncate_middle_tokens(content_str, enc, 20_000)
|
||||||
|
m["content"] = content_str
|
||||||
|
|
||||||
|
# ---- STEP 1 : token-aware truncation ---------------------------------
|
||||||
|
cap = start_cap
|
||||||
|
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
|
||||||
|
for m in msgs[1:-1]: # keep first & last intact
|
||||||
|
if _is_tool_message(m):
|
||||||
|
# For tool messages, only truncate tool result content, preserve structure
|
||||||
|
_truncate_tool_message_content(m, enc, cap)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if _is_objective_message(m):
|
||||||
|
# Never truncate objective messages - they contain the core task
|
||||||
|
continue
|
||||||
|
|
||||||
|
content = m.get("content") or ""
|
||||||
|
if _tok_len(content, enc) > cap:
|
||||||
|
m["content"] = _truncate_middle_tokens(content, enc, cap)
|
||||||
|
cap //= 2 # tighten the screw
|
||||||
|
|
||||||
|
# ---- STEP 2 : middle-out deletion -----------------------------------
|
||||||
|
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
|
||||||
|
# Identify all deletable messages (not first/last, not tool messages, not objective messages)
|
||||||
|
deletable_indices = []
|
||||||
|
for i in range(1, len(msgs) - 1): # Skip first and last
|
||||||
|
if not _is_tool_message(msgs[i]) and not _is_objective_message(msgs[i]):
|
||||||
|
deletable_indices.append(i)
|
||||||
|
|
||||||
|
if not deletable_indices:
|
||||||
|
break # nothing more we can drop
|
||||||
|
|
||||||
|
# Delete from center outward - find the index closest to center
|
||||||
|
centre = len(msgs) // 2
|
||||||
|
to_delete = min(deletable_indices, key=lambda i: abs(i - centre))
|
||||||
|
del msgs[to_delete]
|
||||||
|
|
||||||
|
# ---- STEP 3 : final safety-net trim on first & last ------------------
|
||||||
|
cap = start_cap
|
||||||
|
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
|
||||||
|
for idx in (0, -1): # first and last
|
||||||
|
if _is_tool_message(msgs[idx]):
|
||||||
|
# For tool messages at first/last position, truncate tool result content only
|
||||||
|
_truncate_tool_message_content(msgs[idx], enc, cap)
|
||||||
|
continue
|
||||||
|
|
||||||
|
text = msgs[idx].get("content") or ""
|
||||||
|
if _tok_len(text, enc) > cap:
|
||||||
|
msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap)
|
||||||
|
cap //= 2 # tighten the screw
|
||||||
|
|
||||||
|
# ---- STEP 4 : success or fail-gracefully -----------------------------
|
||||||
|
if total_tokens() + reserve > target_tokens and not lossy_ok:
|
||||||
|
raise ValueError(
|
||||||
|
"compress_prompt: prompt still exceeds budget "
|
||||||
|
f"({total_tokens() + reserve} > {target_tokens})."
|
||||||
|
)
|
||||||
|
|
||||||
|
return msgs
|
||||||
|
|
||||||
|
|
||||||
def estimate_token_count(
|
def estimate_token_count(
|
||||||
messages: list[dict],
|
messages: list[dict],
|
||||||
*,
|
*,
|
||||||
@@ -175,8 +293,7 @@ def estimate_token_count(
|
|||||||
-------
|
-------
|
||||||
int – Token count.
|
int – Token count.
|
||||||
"""
|
"""
|
||||||
token_model = _normalize_model_for_tokenizer(model)
|
enc = encoding_for_model(model) # best-match tokenizer
|
||||||
enc = encoding_for_model(token_model)
|
|
||||||
return sum(_msg_tokens(m, enc) for m in messages)
|
return sum(_msg_tokens(m, enc) for m in messages)
|
||||||
|
|
||||||
|
|
||||||
@@ -198,543 +315,6 @@ def estimate_token_count_str(
|
|||||||
-------
|
-------
|
||||||
int – Token count.
|
int – Token count.
|
||||||
"""
|
"""
|
||||||
token_model = _normalize_model_for_tokenizer(model)
|
enc = encoding_for_model(model) # best-match tokenizer
|
||||||
enc = encoding_for_model(token_model)
|
|
||||||
text = json.dumps(text) if not isinstance(text, str) else text
|
text = json.dumps(text) if not isinstance(text, str) else text
|
||||||
return _tok_len(text, enc)
|
return _tok_len(text, enc)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------#
|
|
||||||
# UNIFIED CONTEXT COMPRESSION #
|
|
||||||
# ---------------------------------------------------------------------------#
|
|
||||||
|
|
||||||
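# Defaults for compress_context(): token budget (target_tokens) and number of recent messages kept verbatim (keep_recent)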
|
|
||||||
DEFAULT_TOKEN_THRESHOLD = 120_000
|
|
||||||
DEFAULT_KEEP_RECENT = 15
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CompressResult:
|
|
||||||
"""Result of context compression."""
|
|
||||||
|
|
||||||
messages: list[dict]
|
|
||||||
token_count: int
|
|
||||||
was_compacted: bool
|
|
||||||
error: str | None = None
|
|
||||||
original_token_count: int = 0
|
|
||||||
messages_summarized: int = 0
|
|
||||||
messages_dropped: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_model_for_tokenizer(model: str) -> str:
|
|
||||||
"""Normalize model name for tiktoken tokenizer selection."""
|
|
||||||
if "/" in model:
|
|
||||||
model = model.split("/")[-1]
|
|
||||||
if "claude" in model.lower() or not any(
|
|
||||||
known in model.lower() for known in ["gpt", "o1", "chatgpt", "text-"]
|
|
||||||
):
|
|
||||||
return "gpt-4o"
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_tool_call_ids_from_message(msg: dict) -> set[str]:
|
|
||||||
"""
|
|
||||||
Extract tool_call IDs from an assistant message.
|
|
||||||
|
|
||||||
Supports both formats:
|
|
||||||
- OpenAI: {"role": "assistant", "tool_calls": [{"id": "..."}]}
|
|
||||||
- Anthropic: {"role": "assistant", "content": [{"type": "tool_use", "id": "..."}]}
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Set of tool_call IDs found in the message.
|
|
||||||
"""
|
|
||||||
ids: set[str] = set()
|
|
||||||
if msg.get("role") != "assistant":
|
|
||||||
return ids
|
|
||||||
|
|
||||||
# OpenAI format: tool_calls array
|
|
||||||
if msg.get("tool_calls"):
|
|
||||||
for tc in msg["tool_calls"]:
|
|
||||||
tc_id = tc.get("id")
|
|
||||||
if tc_id:
|
|
||||||
ids.add(tc_id)
|
|
||||||
|
|
||||||
# Anthropic format: content list with tool_use blocks
|
|
||||||
content = msg.get("content")
|
|
||||||
if isinstance(content, list):
|
|
||||||
for block in content:
|
|
||||||
if isinstance(block, dict) and block.get("type") == "tool_use":
|
|
||||||
tc_id = block.get("id")
|
|
||||||
if tc_id:
|
|
||||||
ids.add(tc_id)
|
|
||||||
|
|
||||||
return ids
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_tool_response_ids_from_message(msg: dict) -> set[str]:
|
|
||||||
"""
|
|
||||||
Extract tool_call IDs that this message is responding to.
|
|
||||||
|
|
||||||
Supports both formats:
|
|
||||||
- OpenAI: {"role": "tool", "tool_call_id": "..."}
|
|
||||||
- Anthropic: {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "..."}]}
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Set of tool_call IDs this message responds to.
|
|
||||||
"""
|
|
||||||
ids: set[str] = set()
|
|
||||||
|
|
||||||
# OpenAI format: role=tool with tool_call_id
|
|
||||||
if msg.get("role") == "tool":
|
|
||||||
tc_id = msg.get("tool_call_id")
|
|
||||||
if tc_id:
|
|
||||||
ids.add(tc_id)
|
|
||||||
|
|
||||||
# Anthropic format: content list with tool_result blocks
|
|
||||||
content = msg.get("content")
|
|
||||||
if isinstance(content, list):
|
|
||||||
for block in content:
|
|
||||||
if isinstance(block, dict) and block.get("type") == "tool_result":
|
|
||||||
tc_id = block.get("tool_use_id")
|
|
||||||
if tc_id:
|
|
||||||
ids.add(tc_id)
|
|
||||||
|
|
||||||
return ids
|
|
||||||
|
|
||||||
|
|
||||||
def _is_tool_response_message(msg: dict) -> bool:
|
|
||||||
"""Check if message is a tool response (OpenAI or Anthropic format)."""
|
|
||||||
# OpenAI format
|
|
||||||
if msg.get("role") == "tool":
|
|
||||||
return True
|
|
||||||
# Anthropic format
|
|
||||||
content = msg.get("content")
|
|
||||||
if isinstance(content, list):
|
|
||||||
for block in content:
|
|
||||||
if isinstance(block, dict) and block.get("type") == "tool_result":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _remove_orphan_tool_responses(
|
|
||||||
messages: list[dict], orphan_ids: set[str]
|
|
||||||
) -> list[dict]:
|
|
||||||
"""
|
|
||||||
Remove tool response messages/blocks that reference orphan tool_call IDs.
|
|
||||||
|
|
||||||
Supports both OpenAI and Anthropic formats.
|
|
||||||
For Anthropic messages with mixed valid/orphan tool_result blocks,
|
|
||||||
filters out only the orphan blocks instead of dropping the entire message.
|
|
||||||
"""
|
|
||||||
result = []
|
|
||||||
for msg in messages:
|
|
||||||
# OpenAI format: role=tool - drop entire message if orphan
|
|
||||||
if msg.get("role") == "tool":
|
|
||||||
tc_id = msg.get("tool_call_id")
|
|
||||||
if tc_id and tc_id in orphan_ids:
|
|
||||||
continue
|
|
||||||
result.append(msg)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Anthropic format: content list may have mixed tool_result blocks
|
|
||||||
content = msg.get("content")
|
|
||||||
if isinstance(content, list):
|
|
||||||
has_tool_results = any(
|
|
||||||
isinstance(b, dict) and b.get("type") == "tool_result" for b in content
|
|
||||||
)
|
|
||||||
if has_tool_results:
|
|
||||||
# Filter out orphan tool_result blocks, keep valid ones
|
|
||||||
filtered_content = [
|
|
||||||
block
|
|
||||||
for block in content
|
|
||||||
if not (
|
|
||||||
isinstance(block, dict)
|
|
||||||
and block.get("type") == "tool_result"
|
|
||||||
and block.get("tool_use_id") in orphan_ids
|
|
||||||
)
|
|
||||||
]
|
|
||||||
# Only keep message if it has remaining content
|
|
||||||
if filtered_content:
|
|
||||||
msg = msg.copy()
|
|
||||||
msg["content"] = filtered_content
|
|
||||||
result.append(msg)
|
|
||||||
continue
|
|
||||||
|
|
||||||
result.append(msg)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _ensure_tool_pairs_intact(
    recent_messages: list[dict],
    all_messages: list[dict],
    start_index: int,
) -> list[dict]:
    """
    Repair a message slice so no tool response is left without its tool call.

    Cutting a conversation at an arbitrary index can keep a tool response
    while dropping the assistant message that issued the call, which APIs
    reject (e.g. Anthropic's "unexpected tool_use_id found in tool_result
    blocks"). This function detects such orphan responses and walks backwards
    through `all_messages` to pull the issuing assistant messages (and the
    tool responses directly following them) in front of the slice.
    Responses whose call cannot be found at all are removed instead.

    Both formats are handled:
    - OpenAI: tool_calls array + role="tool" responses
    - Anthropic: tool_use blocks + tool_result blocks

    Args:
        recent_messages: The sliced messages to validate
        all_messages: The complete message list (searched for missing assistants)
        start_index: The index in all_messages where recent_messages begins

    Returns:
        The slice, possibly extended at the front and/or with unresolved
        orphan responses removed. Returned as-is when nothing needs fixing.
    """
    if not recent_messages:
        return recent_messages

    # Every tool call that is issued inside the slice itself.
    issued_in_slice: set[str] = set().union(
        *(_extract_tool_call_ids_from_message(m) for m in recent_messages)
    )

    # Responses in the slice whose issuing call is outside of it.
    orphan_tool_call_ids: set[str] = {
        response_id
        for m in recent_messages
        for response_id in _extract_tool_response_ids_from_message(m)
        if response_id not in issued_in_slice
    }
    if not orphan_tool_call_ids:
        return recent_messages

    # Walk backwards from just before the slice looking for the issuers.
    prefix: list[dict] = []
    for idx in reversed(range(start_index)):
        candidate = all_messages[idx]
        candidate_ids = _extract_tool_call_ids_from_message(candidate)
        if candidate_ids.isdisjoint(orphan_tool_call_ids):
            continue

        # Take the assistant plus its own responses that sit between it and
        # the slice; stop scanning at the first message that is not a tool
        # response at all.
        group: list[dict] = [candidate]
        for follower in all_messages[idx + 1 : start_index]:
            follower_ids = _extract_tool_response_ids_from_message(follower)
            if follower_ids and follower_ids & candidate_ids:
                group.append(follower)
            elif not _is_tool_response_message(follower):
                break

        # Older groups go in front so chronological order is preserved.
        prefix = group + prefix
        orphan_tool_call_ids -= candidate_ids

        if not orphan_tool_call_ids:
            break

    if orphan_tool_call_ids:
        # Not expected in normal operation: the issuing call is nowhere to be
        # found, so the dangling responses have to go.
        logger.warning(
            f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
            "Removing orphan tool responses."
        )
        recent_messages = _remove_orphan_tool_responses(
            recent_messages, orphan_tool_call_ids
        )

    if not prefix:
        return recent_messages

    logger.info(
        f"Extended recent messages by {len(prefix)} to preserve "
        f"tool_call/tool_response pairs"
    )
    return prefix + recent_messages
|
|
||||||
|
|
||||||
|
|
||||||
async def _summarize_messages_llm(
    messages: list[dict],
    client: AsyncOpenAI,
    model: str,
    timeout: float = 30.0,
) -> str:
    """Ask an LLM for a structured summary of the given messages.

    Only messages with a truthy "content" and a role of user, assistant or
    tool contribute to the transcript. The transcript is cut to 100,000
    characters before being sent.

    Args:
        messages: Messages to summarize.
        client: Client used for the chat completion request.
        model: Model name passed to the completion request.
        timeout: Per-request timeout applied via `client.with_options`.

    Returns:
        The summary text, or a fixed placeholder string when there is nothing
        to summarize or the model returns no content.
    """
    summarizable_roles = ("user", "assistant", "tool")
    transcript_parts: list[str] = []
    for entry in messages:
        speaker = entry.get("role", "")
        body = entry.get("content", "")
        if not body or speaker not in summarizable_roles:
            continue
        transcript_parts.append(f"{speaker.upper()}: {body}")

    conversation_text = "\n\n".join(transcript_parts)
    if not conversation_text:
        return "No conversation history available."

    # Hard character ceiling as a safety net against oversized requests.
    MAX_CHARS = 100_000
    if len(conversation_text) > MAX_CHARS:
        conversation_text = conversation_text[:MAX_CHARS] + "\n\n[truncated]"

    system_prompt = (
        "Create a detailed summary of the conversation so far. "
        "This summary will be used as context when continuing the conversation.\n\n"
        "Before writing the summary, analyze each message chronologically to identify:\n"
        "- User requests and their explicit goals\n"
        "- Your approach and key decisions made\n"
        "- Technical specifics (file names, tool outputs, function signatures)\n"
        "- Errors encountered and resolutions applied\n\n"
        "You MUST include ALL of the following sections:\n\n"
        "## 1. Primary Request and Intent\n"
        "The user's explicit goals and what they are trying to accomplish.\n\n"
        "## 2. Key Technical Concepts\n"
        "Technologies, frameworks, tools, and patterns being used or discussed.\n\n"
        "## 3. Files and Resources Involved\n"
        "Specific files examined or modified, with relevant snippets and identifiers.\n\n"
        "## 4. Errors and Fixes\n"
        "Problems encountered, error messages, and their resolutions. "
        "Include any user feedback on fixes.\n\n"
        "## 5. Problem Solving\n"
        "Issues that have been resolved and how they were addressed.\n\n"
        "## 6. All User Messages\n"
        "A complete list of all user inputs (excluding tool outputs) to preserve their exact requests.\n\n"
        "## 7. Pending Tasks\n"
        "Work items the user explicitly requested that have not yet been completed.\n\n"
        "## 8. Current Work\n"
        "Precise description of what was being worked on most recently, including relevant context.\n\n"
        "## 9. Next Steps\n"
        "What should happen next, aligned with the user's most recent requests. "
        "Include verbatim quotes of recent instructions if relevant."
    )
    request_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Summarize:\n\n{conversation_text}"},
    ]

    timed_client = client.with_options(timeout=timeout)
    response = await timed_client.chat.completions.create(
        model=model,
        messages=request_messages,
        max_tokens=1500,
        temperature=0.3,
    )

    summary = response.choices[0].message.content
    return summary or "No summary available."
|
|
||||||
|
|
||||||
|
|
||||||
async def compress_context(
    messages: list[dict],
    target_tokens: int = DEFAULT_TOKEN_THRESHOLD,
    *,
    model: str = "gpt-4o",
    client: AsyncOpenAI | None = None,
    keep_recent: int = DEFAULT_KEEP_RECENT,
    reserve: int = 2_048,
    start_cap: int = 8_192,
    floor_cap: int = 128,
) -> CompressResult:
    """
    Unified context compression that combines summarization and truncation strategies.

    Strategy (in order; each later step only runs while the prompt is still
    over budget):
    1. **LLM summarization** – If client provided, summarize old messages into a
       single context message while keeping recent messages intact. A failed
       summarization is logged and the remaining steps take over.
    2. **Content truncation** – Non-string content of middle messages is first
       serialized to JSON, then a per-message cap is progressively halved and
       oversized content is truncated. Preserves all messages but shortens
       their content. Primary strategy when client=None.
    3. **Middle-out deletion** – Delete whole messages one at a time from the center
       outward, skipping tool messages and objective messages.
    4. **First/last trim** – Truncate first and last message content as last resort.

    Parameters
    ----------
    messages        Complete chat history (will be deep-copied).
    target_tokens   Hard ceiling for prompt size.
    model           Model name for tokenization and summarization.
    client          AsyncOpenAI client. If provided, enables LLM summarization
                    as the first strategy. If None, skips to truncation strategies.
    keep_recent     Number of recent messages to preserve during summarization.
                    Values below zero are treated as zero; zero means every
                    message except a leading system message is summarized.
    reserve         Tokens to reserve for model response.
    start_cap       Initial per-message truncation ceiling (tokens).
    floor_cap       Lowest cap before moving to deletions.

    Returns
    -------
    CompressResult with compressed messages and metadata. `error` is set (and
    logged) when the result still exceeds the budget; no exception is raised
    for that case.
    """
    # Guard clause for empty messages
    if not messages:
        return CompressResult(
            messages=[],
            token_count=0,
            was_compacted=False,
            original_token_count=0,
        )

    token_model = _normalize_model_for_tokenizer(model)
    enc = encoding_for_model(token_model)
    msgs = deepcopy(messages)

    def total_tokens() -> int:
        # Reads `msgs` at call time, so it follows the reassignment in step 1.
        return sum(_msg_tokens(m, enc) for m in msgs)

    def over_budget() -> bool:
        return total_tokens() + reserve > target_tokens

    original_count = total_tokens()

    # Already under limit
    if original_count + reserve <= target_tokens:
        return CompressResult(
            messages=msgs,
            token_count=original_count,
            was_compacted=False,
            original_token_count=original_count,
        )

    messages_summarized = 0
    messages_dropped = 0

    # ---- STEP 1: LLM summarization (if client provided) -------------------
    # Summarize old messages while keeping recent ones intact.
    if client is not None:
        has_system = msgs[0].get("role") == "system"
        first_idx = 1 if has_system else 0

        # Use an explicit split index instead of negative slicing:
        # `msgs[-0:]` is the whole list and `msgs[1:-0]` is empty, so
        # keep_recent=0 would otherwise disable summarization entirely.
        split = max(first_idx, len(msgs) - max(0, keep_recent))
        old_msgs = msgs[first_idx:split]

        # Ensure tool pairs stay intact
        recent_msgs = _ensure_tool_pairs_intact(msgs[split:], msgs, split)

        if old_msgs:
            try:
                summary_text = await _summarize_messages_llm(old_msgs, client, model)
                summary_msg = {
                    "role": "assistant",
                    "content": f"[Previous conversation summary — for context only]: {summary_text}",
                }
                messages_summarized = len(old_msgs)

                leading = [msgs[0]] if has_system else []
                msgs = leading + [summary_msg] + recent_msgs

                logger.info(
                    f"Context summarized: {original_count} -> {total_tokens()} tokens, "
                    f"summarized {messages_summarized} messages"
                )
            except Exception as e:
                # Deliberate best-effort: fall through to content truncation.
                logger.warning(f"Summarization failed, continuing with truncation: {e}")

    # ---- STEP 2: Normalize content ----------------------------------------
    # Convert non-string payloads to strings so token counting is coherent.
    # Tool messages and the first/last message are left in their original form.
    last_idx = len(msgs) - 1
    for i, m in enumerate(msgs):
        raw_content = m.get("content")
        if raw_content is None or isinstance(raw_content, str):
            continue
        if _is_tool_message(m) or i in (0, last_idx):
            continue
        content_str = json.dumps(raw_content, separators=(",", ":"))
        if len(content_str) > 20_000:
            content_str = _truncate_middle_tokens(content_str, enc, 20_000)
        m["content"] = content_str

    # ---- STEP 3: Token-aware content truncation ---------------------------
    # Progressively halve per-message cap and truncate bloated content.
    # This preserves all messages but shortens their content.
    cap = start_cap
    while over_budget() and cap >= floor_cap:
        for m in msgs[1:-1]:
            if _is_tool_message(m):
                _truncate_tool_message_content(m, enc, cap)
                continue
            if _is_objective_message(m):
                continue
            content = m.get("content") or ""
            if _tok_len(content, enc) > cap:
                m["content"] = _truncate_middle_tokens(content, enc, cap)
        cap //= 2

    # ---- STEP 4: Middle-out deletion --------------------------------------
    # Delete messages one at a time from the center outward.
    while over_budget() and len(msgs) > 2:
        deletable = [
            i
            for i in range(1, len(msgs) - 1)
            if not _is_tool_message(msgs[i]) and not _is_objective_message(msgs[i])
        ]
        if not deletable:
            break
        centre = len(msgs) // 2
        to_delete = min(deletable, key=lambda i: abs(i - centre))
        del msgs[to_delete]
        messages_dropped += 1

    # ---- STEP 5: Final trim on first/last ---------------------------------
    cap = start_cap
    while over_budget() and cap >= floor_cap:
        for idx in (0, -1):
            msg = msgs[idx]
            if _is_tool_message(msg):
                _truncate_tool_message_content(msg, enc, cap)
                continue
            text = msg.get("content") or ""
            if _tok_len(text, enc) > cap:
                msg["content"] = _truncate_middle_tokens(text, enc, cap)
        cap //= 2

    final_count = total_tokens()
    error = None
    if final_count + reserve > target_tokens:
        error = f"Could not compress below target ({final_count + reserve} > {target_tokens})"
        logger.warning(error)

    return CompressResult(
        messages=msgs,
        token_count=final_count,
        was_compacted=True,
        error=error,
        original_token_count=original_count,
        messages_summarized=messages_summarized,
        messages_dropped=messages_dropped,
    )
|
|
||||||
|
|||||||
@@ -1,21 +1,10 @@
|
|||||||
"""Tests for prompt utility functions, especially tool call token counting."""
|
"""Tests for prompt utility functions, especially tool call token counting."""
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from tiktoken import encoding_for_model
|
from tiktoken import encoding_for_model
|
||||||
|
|
||||||
from backend.util import json
|
from backend.util import json
|
||||||
from backend.util.prompt import (
|
from backend.util.prompt import _msg_tokens, estimate_token_count
|
||||||
CompressResult,
|
|
||||||
_ensure_tool_pairs_intact,
|
|
||||||
_msg_tokens,
|
|
||||||
_normalize_model_for_tokenizer,
|
|
||||||
_truncate_middle_tokens,
|
|
||||||
_truncate_tool_message_content,
|
|
||||||
compress_context,
|
|
||||||
estimate_token_count,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestMsgTokens:
|
class TestMsgTokens:
|
||||||
@@ -287,690 +276,3 @@ class TestEstimateTokenCount:
|
|||||||
|
|
||||||
assert total_tokens == expected_total
|
assert total_tokens == expected_total
|
||||||
assert total_tokens > 20 # Should be substantial
|
assert total_tokens > 20 # Should be substantial
|
||||||
|
|
||||||
|
|
||||||
class TestNormalizeModelForTokenizer:
    """Checks how model names are mapped to tokenizer-compatible names."""

    def test_openai_models_unchanged(self):
        """OpenAI model names pass through untouched."""
        for name in ("gpt-4o", "gpt-4", "gpt-3.5-turbo"):
            assert _normalize_model_for_tokenizer(name) == name

    def test_claude_models_normalized(self):
        """Claude model names (with or without provider prefix) map to gpt-4o."""
        for name in ("claude-3-opus", "claude-3-sonnet", "anthropic/claude-3-haiku"):
            assert _normalize_model_for_tokenizer(name) == "gpt-4o"

    def test_openrouter_paths_extracted(self):
        """Provider-prefixed (OpenRouter style) paths are reduced to the model name."""
        expectations = {
            "openai/gpt-4o": "gpt-4o",
            "anthropic/claude-3-opus": "gpt-4o",
        }
        for path, expected in expectations.items():
            assert _normalize_model_for_tokenizer(path) == expected

    def test_unknown_models_default_to_gpt4o(self):
        """Names from unrecognised model families fall back to gpt-4o."""
        for name in ("some-random-model", "llama-3-70b"):
            assert _normalize_model_for_tokenizer(name) == "gpt-4o"
|
|
||||||
|
|
||||||
|
|
||||||
class TestTruncateToolMessageContent:
    """Checks in-place truncation of tool message payloads."""

    @pytest.fixture
    def enc(self):
        return encoding_for_model("gpt-4o")

    def test_truncate_openai_tool_message(self, enc):
        """An OpenAI-style tool message with long string content is shortened."""
        payload = "x" * 10000
        message = {"role": "tool", "tool_call_id": "call_123", "content": payload}

        _truncate_tool_message_content(message, enc, max_tokens=100)

        shortened = message["content"]
        assert len(shortened) < len(payload)
        assert "…" in shortened  # ellipsis marks the removed middle

    def test_truncate_anthropic_tool_result(self, enc):
        """An Anthropic-style tool_result block is shortened inside its message."""
        payload = "y" * 10000
        result_block = {
            "type": "tool_result",
            "tool_use_id": "toolu_123",
            "content": payload,
        }
        message = {"role": "user", "content": [result_block]}

        _truncate_tool_message_content(message, enc, max_tokens=100)

        shortened = message["content"][0]["content"]
        assert len(shortened) < len(payload)
        assert "…" in shortened

    def test_preserve_tool_use_blocks(self, enc):
        """tool_use blocks keep their input even when it is far above the limit."""
        use_block = {
            "type": "tool_use",
            "id": "toolu_123",
            "name": "some_function",
            "input": {"key": "value" * 1000},  # deliberately large
        }
        message = {"role": "assistant", "content": [use_block]}
        before = json.dumps(message["content"][0]["input"])

        _truncate_tool_message_content(message, enc, max_tokens=10)

        after = json.dumps(message["content"][0]["input"])
        assert after == before

    def test_no_truncation_when_under_limit(self, enc):
        """Content already below the limit is left exactly as it was."""
        message = {"role": "tool", "tool_call_id": "call_123", "content": "Short content"}
        before = message["content"]

        _truncate_tool_message_content(message, enc, max_tokens=1000)

        assert message["content"] == before
|
|
||||||
|
|
||||||
|
|
||||||
class TestTruncateMiddleTokens:
    """Checks middle-out truncation of plain text."""

    @pytest.fixture
    def enc(self):
        return encoding_for_model("gpt-4o")

    def test_truncates_long_text(self, enc):
        """Long text loses its middle while head and tail survive."""
        source = "word " * 1000

        shortened = _truncate_middle_tokens(source, enc, max_tok=50)

        token_total = len(enc.encode(shortened))
        assert token_total <= 52  # small allowance for the ellipsis marker
        assert "…" in shortened
        assert shortened.startswith("word")  # head kept
        assert shortened.endswith("word ")  # tail kept

    def test_preserves_short_text(self, enc):
        """Text below the token limit is returned unchanged."""
        source = "Hello world"

        assert _truncate_middle_tokens(source, enc, max_tok=100) == source
|
|
||||||
|
|
||||||
|
|
||||||
class TestEnsureToolPairsIntact:
|
|
||||||
"""Test tool call/response pair preservation for both OpenAI and Anthropic formats."""
|
|
||||||
|
|
||||||
# ---- OpenAI Format Tests ----
|
|
||||||
|
|
||||||
def test_openai_adds_missing_tool_call(self):
|
|
||||||
"""Test that orphaned OpenAI tool_response gets its tool_call prepended."""
|
|
||||||
all_msgs = [
|
|
||||||
{"role": "system", "content": "You are helpful."},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"tool_calls": [
|
|
||||||
{"id": "call_1", "type": "function", "function": {"name": "f1"}}
|
|
||||||
],
|
|
||||||
},
|
|
||||||
{"role": "tool", "tool_call_id": "call_1", "content": "result"},
|
|
||||||
{"role": "user", "content": "Thanks!"},
|
|
||||||
]
|
|
||||||
# Recent messages start at index 2 (the tool response)
|
|
||||||
recent = [all_msgs[2], all_msgs[3]]
|
|
||||||
start_index = 2
|
|
||||||
|
|
||||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
|
||||||
|
|
||||||
# Should prepend the tool_call message
|
|
||||||
assert len(result) == 3
|
|
||||||
assert result[0]["role"] == "assistant"
|
|
||||||
assert "tool_calls" in result[0]
|
|
||||||
|
|
||||||
def test_openai_keeps_complete_pairs(self):
|
|
||||||
"""Test that complete OpenAI pairs are unchanged."""
|
|
||||||
all_msgs = [
|
|
||||||
{"role": "system", "content": "System"},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"tool_calls": [
|
|
||||||
{"id": "call_1", "type": "function", "function": {"name": "f1"}}
|
|
||||||
],
|
|
||||||
},
|
|
||||||
{"role": "tool", "tool_call_id": "call_1", "content": "result"},
|
|
||||||
]
|
|
||||||
recent = all_msgs[1:] # Include both tool_call and response
|
|
||||||
start_index = 1
|
|
||||||
|
|
||||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
|
||||||
|
|
||||||
assert len(result) == 2 # No messages added
|
|
||||||
|
|
||||||
def test_openai_multiple_tool_calls(self):
|
|
||||||
"""Test multiple OpenAI tool calls in one assistant message."""
|
|
||||||
all_msgs = [
|
|
||||||
{"role": "system", "content": "System"},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"tool_calls": [
|
|
||||||
{"id": "call_1", "type": "function", "function": {"name": "f1"}},
|
|
||||||
{"id": "call_2", "type": "function", "function": {"name": "f2"}},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
{"role": "tool", "tool_call_id": "call_1", "content": "result1"},
|
|
||||||
{"role": "tool", "tool_call_id": "call_2", "content": "result2"},
|
|
||||||
{"role": "user", "content": "Thanks!"},
|
|
||||||
]
|
|
||||||
# Recent messages start at index 2 (first tool response)
|
|
||||||
recent = [all_msgs[2], all_msgs[3], all_msgs[4]]
|
|
||||||
start_index = 2
|
|
||||||
|
|
||||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
|
||||||
|
|
||||||
# Should prepend the assistant message with both tool_calls
|
|
||||||
assert len(result) == 4
|
|
||||||
assert result[0]["role"] == "assistant"
|
|
||||||
assert len(result[0]["tool_calls"]) == 2
|
|
||||||
|
|
||||||
# ---- Anthropic Format Tests ----
|
|
||||||
|
|
||||||
def test_anthropic_adds_missing_tool_use(self):
|
|
||||||
"""Test that orphaned Anthropic tool_result gets its tool_use prepended."""
|
|
||||||
all_msgs = [
|
|
||||||
{"role": "system", "content": "You are helpful."},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "tool_use",
|
|
||||||
"id": "toolu_123",
|
|
||||||
"name": "get_weather",
|
|
||||||
"input": {"location": "SF"},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "tool_result",
|
|
||||||
"tool_use_id": "toolu_123",
|
|
||||||
"content": "22°C and sunny",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
},
|
|
||||||
{"role": "user", "content": "Thanks!"},
|
|
||||||
]
|
|
||||||
# Recent messages start at index 2 (the tool_result)
|
|
||||||
recent = [all_msgs[2], all_msgs[3]]
|
|
||||||
start_index = 2
|
|
||||||
|
|
||||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
|
||||||
|
|
||||||
# Should prepend the tool_use message
|
|
||||||
assert len(result) == 3
|
|
||||||
assert result[0]["role"] == "assistant"
|
|
||||||
assert result[0]["content"][0]["type"] == "tool_use"
|
|
||||||
|
|
||||||
def test_anthropic_keeps_complete_pairs(self):
|
|
||||||
"""Test that complete Anthropic pairs are unchanged."""
|
|
||||||
all_msgs = [
|
|
||||||
{"role": "system", "content": "System"},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "tool_use",
|
|
||||||
"id": "toolu_456",
|
|
||||||
"name": "calculator",
|
|
||||||
"input": {"expr": "2+2"},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "tool_result",
|
|
||||||
"tool_use_id": "toolu_456",
|
|
||||||
"content": "4",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
},
|
|
||||||
]
|
|
||||||
recent = all_msgs[1:] # Include both tool_use and result
|
|
||||||
start_index = 1
|
|
||||||
|
|
||||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
|
||||||
|
|
||||||
assert len(result) == 2 # No messages added
|
|
||||||
|
|
||||||
def test_anthropic_multiple_tool_uses(self):
|
|
||||||
"""Test multiple Anthropic tool_use blocks in one message."""
|
|
||||||
all_msgs = [
|
|
||||||
{"role": "system", "content": "System"},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": "Let me check both..."},
|
|
||||||
{
|
|
||||||
"type": "tool_use",
|
|
||||||
"id": "toolu_1",
|
|
||||||
"name": "get_weather",
|
|
||||||
"input": {"city": "NYC"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "tool_use",
|
|
||||||
"id": "toolu_2",
|
|
||||||
"name": "get_weather",
|
|
||||||
"input": {"city": "LA"},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "tool_result",
|
|
||||||
"tool_use_id": "toolu_1",
|
|
||||||
"content": "Cold",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "tool_result",
|
|
||||||
"tool_use_id": "toolu_2",
|
|
||||||
"content": "Warm",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
{"role": "user", "content": "Thanks!"},
|
|
||||||
]
|
|
||||||
# Recent messages start at index 2 (tool_result)
|
|
||||||
recent = [all_msgs[2], all_msgs[3]]
|
|
||||||
start_index = 2
|
|
||||||
|
|
||||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
|
||||||
|
|
||||||
# Should prepend the assistant message with both tool_uses
|
|
||||||
assert len(result) == 3
|
|
||||||
assert result[0]["role"] == "assistant"
|
|
||||||
tool_use_count = sum(
|
|
||||||
1 for b in result[0]["content"] if b.get("type") == "tool_use"
|
|
||||||
)
|
|
||||||
assert tool_use_count == 2
|
|
||||||
|
|
||||||
# ---- Mixed/Edge Case Tests ----
|
|
||||||
|
|
||||||
def test_anthropic_with_type_message_field(self):
|
|
||||||
"""Test Anthropic format with 'type': 'message' field (smart_decision_maker style)."""
|
|
||||||
all_msgs = [
|
|
||||||
{"role": "system", "content": "You are helpful."},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "tool_use",
|
|
||||||
"id": "toolu_abc",
|
|
||||||
"name": "search",
|
|
||||||
"input": {"q": "test"},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"type": "message", # Extra field from smart_decision_maker
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "tool_result",
|
|
||||||
"tool_use_id": "toolu_abc",
|
|
||||||
"content": "Found results",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
},
|
|
||||||
{"role": "user", "content": "Thanks!"},
|
|
||||||
]
|
|
||||||
# Recent messages start at index 2 (the tool_result with 'type': 'message')
|
|
||||||
recent = [all_msgs[2], all_msgs[3]]
|
|
||||||
start_index = 2
|
|
||||||
|
|
||||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
|
||||||
|
|
||||||
# Should prepend the tool_use message
|
|
||||||
assert len(result) == 3
|
|
||||||
assert result[0]["role"] == "assistant"
|
|
||||||
assert result[0]["content"][0]["type"] == "tool_use"
|
|
||||||
|
|
||||||
def test_handles_no_tool_messages(self):
|
|
||||||
"""Test messages without tool calls."""
|
|
||||||
all_msgs = [
|
|
||||||
{"role": "user", "content": "Hello"},
|
|
||||||
{"role": "assistant", "content": "Hi there!"},
|
|
||||||
]
|
|
||||||
recent = all_msgs
|
|
||||||
start_index = 0
|
|
||||||
|
|
||||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
|
||||||
|
|
||||||
assert result == all_msgs
|
|
||||||
|
|
||||||
def test_handles_empty_messages(self):
|
|
||||||
"""Test empty message list."""
|
|
||||||
result = _ensure_tool_pairs_intact([], [], 0)
|
|
||||||
assert result == []
|
|
||||||
|
|
||||||
def test_mixed_text_and_tool_content(self):
|
|
||||||
"""Test Anthropic message with mixed text and tool_use content."""
|
|
||||||
all_msgs = [
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": "I'll help you with that."},
|
|
||||||
{
|
|
||||||
"type": "tool_use",
|
|
||||||
"id": "toolu_mixed",
|
|
||||||
"name": "search",
|
|
||||||
"input": {"q": "test"},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "tool_result",
|
|
||||||
"tool_use_id": "toolu_mixed",
|
|
||||||
"content": "Found results",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
},
|
|
||||||
{"role": "assistant", "content": "Here are the results..."},
|
|
||||||
]
|
|
||||||
# Start from tool_result
|
|
||||||
recent = [all_msgs[1], all_msgs[2]]
|
|
||||||
start_index = 1
|
|
||||||
|
|
||||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
|
||||||
|
|
||||||
# Should prepend the assistant message with tool_use
|
|
||||||
assert len(result) == 3
|
|
||||||
assert result[0]["content"][0]["type"] == "text"
|
|
||||||
assert result[0]["content"][1]["type"] == "tool_use"
|
|
||||||
|
|
||||||
|
|
||||||
class TestCompressContext:
|
|
||||||
"""Test the async compress_context function."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_no_compression_needed(self):
|
|
||||||
"""Test messages under limit return without compression."""
|
|
||||||
messages = [
|
|
||||||
{"role": "system", "content": "You are helpful."},
|
|
||||||
{"role": "user", "content": "Hello!"},
|
|
||||||
]
|
|
||||||
|
|
||||||
result = await compress_context(messages, target_tokens=100000)
|
|
||||||
|
|
||||||
assert isinstance(result, CompressResult)
|
|
||||||
assert result.was_compacted is False
|
|
||||||
assert len(result.messages) == 2
|
|
||||||
assert result.error is None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_truncation_without_client(self):
|
|
||||||
"""Test that truncation works without LLM client."""
|
|
||||||
long_content = "x" * 50000
|
|
||||||
messages = [
|
|
||||||
{"role": "system", "content": "System"},
|
|
||||||
{"role": "user", "content": long_content},
|
|
||||||
{"role": "assistant", "content": "Response"},
|
|
||||||
]
|
|
||||||
|
|
||||||
result = await compress_context(
|
|
||||||
messages, target_tokens=1000, client=None, reserve=100
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.was_compacted is True
|
|
||||||
# Should have truncated without summarization
|
|
||||||
assert result.messages_summarized == 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_with_mocked_llm_client(self):
|
|
||||||
"""Test summarization with mocked LLM client."""
|
|
||||||
# Create many messages to trigger summarization
|
|
||||||
messages = [{"role": "system", "content": "System prompt"}]
|
|
||||||
for i in range(30):
|
|
||||||
messages.append({"role": "user", "content": f"User message {i} " * 100})
|
|
||||||
messages.append(
|
|
||||||
{"role": "assistant", "content": f"Assistant response {i} " * 100}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock the AsyncOpenAI client
|
|
||||||
mock_client = AsyncMock()
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.choices = [MagicMock()]
|
|
||||||
mock_response.choices[0].message.content = "Summary of conversation"
|
|
||||||
mock_client.with_options.return_value.chat.completions.create = AsyncMock(
|
|
||||||
return_value=mock_response
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await compress_context(
|
|
||||||
messages,
|
|
||||||
target_tokens=5000,
|
|
||||||
client=mock_client,
|
|
||||||
keep_recent=5,
|
|
||||||
reserve=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.was_compacted is True
|
|
||||||
# Should have attempted summarization
|
|
||||||
assert mock_client.with_options.called or result.messages_summarized > 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_preserves_tool_pairs(self):
|
|
||||||
"""Test that tool call/response pairs stay together."""
|
|
||||||
messages = [
|
|
||||||
{"role": "system", "content": "System"},
|
|
||||||
{"role": "user", "content": "Do something"},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"tool_calls": [
|
|
||||||
{"id": "call_1", "type": "function", "function": {"name": "func"}}
|
|
||||||
],
|
|
||||||
},
|
|
||||||
{"role": "tool", "tool_call_id": "call_1", "content": "Result " * 1000},
|
|
||||||
{"role": "assistant", "content": "Done!"},
|
|
||||||
]
|
|
||||||
|
|
||||||
result = await compress_context(
|
|
||||||
messages, target_tokens=500, client=None, reserve=50
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check that if tool response exists, its call exists too
|
|
||||||
tool_call_ids = set()
|
|
||||||
tool_response_ids = set()
|
|
||||||
for msg in result.messages:
|
|
||||||
if "tool_calls" in msg:
|
|
||||||
for tc in msg["tool_calls"]:
|
|
||||||
tool_call_ids.add(tc["id"])
|
|
||||||
if msg.get("role") == "tool":
|
|
||||||
tool_response_ids.add(msg.get("tool_call_id"))
|
|
||||||
|
|
||||||
# All tool responses should have their calls
|
|
||||||
assert tool_response_ids <= tool_call_ids
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_returns_error_when_cannot_compress(self):
|
|
||||||
"""Test that error is returned when compression fails."""
|
|
||||||
# Single huge message that can't be compressed enough
|
|
||||||
messages = [
|
|
||||||
{"role": "user", "content": "x" * 100000},
|
|
||||||
]
|
|
||||||
|
|
||||||
result = await compress_context(
|
|
||||||
messages, target_tokens=100, client=None, reserve=50
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should have an error since we can't get below 100 tokens
|
|
||||||
assert result.error is not None
|
|
||||||
assert result.was_compacted is True
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_empty_messages(self):
|
|
||||||
"""Test that empty messages list returns early without error."""
|
|
||||||
result = await compress_context([], target_tokens=1000)
|
|
||||||
|
|
||||||
assert result.messages == []
|
|
||||||
assert result.token_count == 0
|
|
||||||
assert result.was_compacted is False
|
|
||||||
assert result.error is None
|
|
||||||
|
|
||||||
|
|
||||||
class TestRemoveOrphanToolResponses:
|
|
||||||
"""Test _remove_orphan_tool_responses helper function."""
|
|
||||||
|
|
||||||
def test_removes_openai_orphan(self):
|
|
||||||
"""Test removal of orphan OpenAI tool response."""
|
|
||||||
from backend.util.prompt import _remove_orphan_tool_responses
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "tool", "tool_call_id": "call_orphan", "content": "result"},
|
|
||||||
{"role": "user", "content": "Hello"},
|
|
||||||
]
|
|
||||||
orphan_ids = {"call_orphan"}
|
|
||||||
|
|
||||||
result = _remove_orphan_tool_responses(messages, orphan_ids)
|
|
||||||
|
|
||||||
assert len(result) == 1
|
|
||||||
assert result[0]["role"] == "user"
|
|
||||||
|
|
||||||
def test_keeps_valid_openai_tool(self):
|
|
||||||
"""Test that valid OpenAI tool responses are kept."""
|
|
||||||
from backend.util.prompt import _remove_orphan_tool_responses
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "tool", "tool_call_id": "call_valid", "content": "result"},
|
|
||||||
]
|
|
||||||
orphan_ids = {"call_other"}
|
|
||||||
|
|
||||||
result = _remove_orphan_tool_responses(messages, orphan_ids)
|
|
||||||
|
|
||||||
assert len(result) == 1
|
|
||||||
assert result[0]["tool_call_id"] == "call_valid"
|
|
||||||
|
|
||||||
def test_filters_anthropic_mixed_blocks(self):
|
|
||||||
"""Test filtering individual orphan blocks from Anthropic message with mixed valid/orphan."""
|
|
||||||
from backend.util.prompt import _remove_orphan_tool_responses
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "tool_result",
|
|
||||||
"tool_use_id": "toolu_valid",
|
|
||||||
"content": "valid result",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "tool_result",
|
|
||||||
"tool_use_id": "toolu_orphan",
|
|
||||||
"content": "orphan result",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
]
|
|
||||||
orphan_ids = {"toolu_orphan"}
|
|
||||||
|
|
||||||
result = _remove_orphan_tool_responses(messages, orphan_ids)
|
|
||||||
|
|
||||||
assert len(result) == 1
|
|
||||||
# Should only have the valid tool_result, orphan filtered out
|
|
||||||
assert len(result[0]["content"]) == 1
|
|
||||||
assert result[0]["content"][0]["tool_use_id"] == "toolu_valid"
|
|
||||||
|
|
||||||
def test_removes_anthropic_all_orphan(self):
|
|
||||||
"""Test removal of Anthropic message when all tool_results are orphans."""
|
|
||||||
from backend.util.prompt import _remove_orphan_tool_responses
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "tool_result",
|
|
||||||
"tool_use_id": "toolu_orphan1",
|
|
||||||
"content": "result1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "tool_result",
|
|
||||||
"tool_use_id": "toolu_orphan2",
|
|
||||||
"content": "result2",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
]
|
|
||||||
orphan_ids = {"toolu_orphan1", "toolu_orphan2"}
|
|
||||||
|
|
||||||
result = _remove_orphan_tool_responses(messages, orphan_ids)
|
|
||||||
|
|
||||||
# Message should be completely removed since no content left
|
|
||||||
assert len(result) == 0
|
|
||||||
|
|
||||||
def test_preserves_non_tool_messages(self):
|
|
||||||
"""Test that non-tool messages are preserved."""
|
|
||||||
from backend.util.prompt import _remove_orphan_tool_responses
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "user", "content": "Hello"},
|
|
||||||
{"role": "assistant", "content": "Hi there!"},
|
|
||||||
]
|
|
||||||
orphan_ids = {"some_id"}
|
|
||||||
|
|
||||||
result = _remove_orphan_tool_responses(messages, orphan_ids)
|
|
||||||
|
|
||||||
assert result == messages
|
|
||||||
|
|
||||||
|
|
||||||
class TestCompressResultDataclass:
|
|
||||||
"""Test CompressResult dataclass."""
|
|
||||||
|
|
||||||
def test_default_values(self):
|
|
||||||
"""Test default values are set correctly."""
|
|
||||||
result = CompressResult(
|
|
||||||
messages=[{"role": "user", "content": "test"}],
|
|
||||||
token_count=10,
|
|
||||||
was_compacted=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.error is None
|
|
||||||
assert result.original_token_count == 0 # Defaults to 0, not None
|
|
||||||
assert result.messages_summarized == 0
|
|
||||||
assert result.messages_dropped == 0
|
|
||||||
|
|
||||||
def test_all_fields(self):
|
|
||||||
"""Test all fields can be set."""
|
|
||||||
result = CompressResult(
|
|
||||||
messages=[{"role": "user", "content": "test"}],
|
|
||||||
token_count=100,
|
|
||||||
was_compacted=True,
|
|
||||||
error="Some error",
|
|
||||||
original_token_count=500,
|
|
||||||
messages_summarized=10,
|
|
||||||
messages_dropped=5,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.token_count == 100
|
|
||||||
assert result.was_compacted is True
|
|
||||||
assert result.error == "Some error"
|
|
||||||
assert result.original_token_count == 500
|
|
||||||
assert result.messages_summarized == 10
|
|
||||||
assert result.messages_dropped == 5
|
|
||||||
|
|||||||
@@ -111,7 +111,9 @@ class TestGenerateAgent:
|
|||||||
instructions = {"type": "instructions", "steps": ["Step 1"]}
|
instructions = {"type": "instructions", "steps": ["Step 1"]}
|
||||||
result = await core.generate_agent(instructions)
|
result = await core.generate_agent(instructions)
|
||||||
|
|
||||||
mock_external.assert_called_once_with(instructions, None, None, None)
|
# library_agents defaults to None
|
||||||
|
mock_external.assert_called_once_with(instructions, None)
|
||||||
|
# Result should have id, version, is_active added if not present
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["name"] == "Test Agent"
|
assert result["name"] == "Test Agent"
|
||||||
assert "id" in result
|
assert "id" in result
|
||||||
@@ -175,9 +177,8 @@ class TestGenerateAgentPatch:
|
|||||||
current_agent = {"nodes": [], "links": []}
|
current_agent = {"nodes": [], "links": []}
|
||||||
result = await core.generate_agent_patch("Add a node", current_agent)
|
result = await core.generate_agent_patch("Add a node", current_agent)
|
||||||
|
|
||||||
mock_external.assert_called_once_with(
|
# library_agents defaults to None
|
||||||
"Add a node", current_agent, None, None, None
|
mock_external.assert_called_once_with("Add a node", current_agent, None)
|
||||||
)
|
|
||||||
assert result == expected_result
|
assert result == expected_result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ class TestDecomposeGoalExternal:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_decompose_goal_with_context(self):
|
async def test_decompose_goal_with_context(self):
|
||||||
"""Test decomposition with additional context enriched into description."""
|
"""Test decomposition with additional context."""
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.json.return_value = {
|
mock_response.json.return_value = {
|
||||||
"success": True,
|
"success": True,
|
||||||
@@ -119,12 +119,9 @@ class TestDecomposeGoalExternal:
|
|||||||
"Build a chatbot", context="Use Python"
|
"Build a chatbot", context="Use Python"
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_description = (
|
|
||||||
"Build a chatbot\n\nAdditional context from user:\nUse Python"
|
|
||||||
)
|
|
||||||
mock_client.post.assert_called_once_with(
|
mock_client.post.assert_called_once_with(
|
||||||
"/api/decompose-description",
|
"/api/decompose-description",
|
||||||
json={"description": expected_description},
|
json={"description": "Build a chatbot", "user_instruction": "Use Python"},
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
"use client";
|
"use client";
|
||||||
import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
|
|
||||||
import { getOnboardingStatus, resolveResponse } from "@/app/api/helpers";
|
|
||||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||||
import { useRouter } from "next/navigation";
|
import { useRouter } from "next/navigation";
|
||||||
import { useEffect } from "react";
|
import { useEffect } from "react";
|
||||||
|
import { resolveResponse, getOnboardingStatus } from "@/app/api/helpers";
|
||||||
|
import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
|
||||||
|
import { getHomepageRoute } from "@/lib/constants";
|
||||||
|
|
||||||
export default function OnboardingPage() {
|
export default function OnboardingPage() {
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
@@ -12,10 +13,12 @@ export default function OnboardingPage() {
|
|||||||
async function redirectToStep() {
|
async function redirectToStep() {
|
||||||
try {
|
try {
|
||||||
// Check if onboarding is enabled (also gets chat flag for redirect)
|
// Check if onboarding is enabled (also gets chat flag for redirect)
|
||||||
const { shouldShowOnboarding } = await getOnboardingStatus();
|
const { shouldShowOnboarding, isChatEnabled } =
|
||||||
|
await getOnboardingStatus();
|
||||||
|
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||||
|
|
||||||
if (!shouldShowOnboarding) {
|
if (!shouldShowOnboarding) {
|
||||||
router.replace("/");
|
router.replace(homepageRoute);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -23,7 +26,7 @@ export default function OnboardingPage() {
|
|||||||
|
|
||||||
// Handle completed onboarding
|
// Handle completed onboarding
|
||||||
if (onboarding.completedSteps.includes("GET_RESULTS")) {
|
if (onboarding.completedSteps.includes("GET_RESULTS")) {
|
||||||
router.replace("/");
|
router.replace(homepageRoute);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import { getOnboardingStatus } from "@/app/api/helpers";
|
|
||||||
import BackendAPI from "@/lib/autogpt-server-api";
|
|
||||||
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
||||||
import { revalidatePath } from "next/cache";
|
import { getHomepageRoute } from "@/lib/constants";
|
||||||
|
import BackendAPI from "@/lib/autogpt-server-api";
|
||||||
import { NextResponse } from "next/server";
|
import { NextResponse } from "next/server";
|
||||||
|
import { revalidatePath } from "next/cache";
|
||||||
|
import { getOnboardingStatus } from "@/app/api/helpers";
|
||||||
|
|
||||||
// Handle the callback to complete the user session login
|
// Handle the callback to complete the user session login
|
||||||
export async function GET(request: Request) {
|
export async function GET(request: Request) {
|
||||||
@@ -26,12 +27,13 @@ export async function GET(request: Request) {
|
|||||||
await api.createUser();
|
await api.createUser();
|
||||||
|
|
||||||
// Get onboarding status from backend (includes chat flag evaluated for this user)
|
// Get onboarding status from backend (includes chat flag evaluated for this user)
|
||||||
const { shouldShowOnboarding } = await getOnboardingStatus();
|
const { shouldShowOnboarding, isChatEnabled } =
|
||||||
|
await getOnboardingStatus();
|
||||||
if (shouldShowOnboarding) {
|
if (shouldShowOnboarding) {
|
||||||
next = "/onboarding";
|
next = "/onboarding";
|
||||||
revalidatePath("/onboarding", "layout");
|
revalidatePath("/onboarding", "layout");
|
||||||
} else {
|
} else {
|
||||||
next = "/";
|
next = getHomepageRoute(isChatEnabled);
|
||||||
revalidatePath(next, "layout");
|
revalidatePath(next, "layout");
|
||||||
}
|
}
|
||||||
} catch (createUserError) {
|
} catch (createUserError) {
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
|||||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||||
import { useQueryClient } from "@tanstack/react-query";
|
import { useQueryClient } from "@tanstack/react-query";
|
||||||
import { usePathname, useSearchParams } from "next/navigation";
|
import { usePathname, useSearchParams } from "next/navigation";
|
||||||
|
import { useRef } from "react";
|
||||||
import { useCopilotStore } from "../../copilot-page-store";
|
import { useCopilotStore } from "../../copilot-page-store";
|
||||||
import { useCopilotSessionId } from "../../useCopilotSessionId";
|
import { useCopilotSessionId } from "../../useCopilotSessionId";
|
||||||
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
|
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
|
||||||
@@ -69,16 +70,41 @@ export function useCopilotShell() {
|
|||||||
});
|
});
|
||||||
|
|
||||||
const stopStream = useChatStore((s) => s.stopStream);
|
const stopStream = useChatStore((s) => s.stopStream);
|
||||||
|
const onStreamComplete = useChatStore((s) => s.onStreamComplete);
|
||||||
|
const isStreaming = useCopilotStore((s) => s.isStreaming);
|
||||||
const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);
|
const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);
|
||||||
|
const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession);
|
||||||
|
const openInterruptModal = useCopilotStore((s) => s.openInterruptModal);
|
||||||
|
|
||||||
function handleSessionClick(sessionId: string) {
|
const pendingActionRef = useRef<(() => void) | null>(null);
|
||||||
if (sessionId === currentSessionId) return;
|
|
||||||
|
|
||||||
// Stop current stream - SSE reconnection allows resuming later
|
async function stopCurrentStream() {
|
||||||
if (currentSessionId) {
|
if (!currentSessionId) return;
|
||||||
|
|
||||||
|
setIsSwitchingSession(true);
|
||||||
|
await new Promise<void>((resolve) => {
|
||||||
|
const unsubscribe = onStreamComplete((completedId) => {
|
||||||
|
if (completedId === currentSessionId) {
|
||||||
|
clearTimeout(timeout);
|
||||||
|
unsubscribe();
|
||||||
|
resolve();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
const timeout = setTimeout(() => {
|
||||||
|
unsubscribe();
|
||||||
|
resolve();
|
||||||
|
}, 3000);
|
||||||
stopStream(currentSessionId);
|
stopStream(currentSessionId);
|
||||||
}
|
});
|
||||||
|
|
||||||
|
queryClient.invalidateQueries({
|
||||||
|
queryKey: getGetV2GetSessionQueryKey(currentSessionId),
|
||||||
|
});
|
||||||
|
setIsSwitchingSession(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
function selectSession(sessionId: string) {
|
||||||
|
if (sessionId === currentSessionId) return;
|
||||||
if (recentlyCreatedSessionsRef.current.has(sessionId)) {
|
if (recentlyCreatedSessionsRef.current.has(sessionId)) {
|
||||||
queryClient.invalidateQueries({
|
queryClient.invalidateQueries({
|
||||||
queryKey: getGetV2GetSessionQueryKey(sessionId),
|
queryKey: getGetV2GetSessionQueryKey(sessionId),
|
||||||
@@ -88,12 +114,7 @@ export function useCopilotShell() {
|
|||||||
if (isMobile) handleCloseDrawer();
|
if (isMobile) handleCloseDrawer();
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleNewChatClick() {
|
function startNewChat() {
|
||||||
// Stop current stream - SSE reconnection allows resuming later
|
|
||||||
if (currentSessionId) {
|
|
||||||
stopStream(currentSessionId);
|
|
||||||
}
|
|
||||||
|
|
||||||
resetPagination();
|
resetPagination();
|
||||||
queryClient.invalidateQueries({
|
queryClient.invalidateQueries({
|
||||||
queryKey: getGetV2ListSessionsQueryKey(),
|
queryKey: getGetV2ListSessionsQueryKey(),
|
||||||
@@ -102,6 +123,32 @@ export function useCopilotShell() {
|
|||||||
if (isMobile) handleCloseDrawer();
|
if (isMobile) handleCloseDrawer();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function handleSessionClick(sessionId: string) {
|
||||||
|
if (sessionId === currentSessionId) return;
|
||||||
|
|
||||||
|
if (isStreaming) {
|
||||||
|
pendingActionRef.current = async () => {
|
||||||
|
await stopCurrentStream();
|
||||||
|
selectSession(sessionId);
|
||||||
|
};
|
||||||
|
openInterruptModal(pendingActionRef.current);
|
||||||
|
} else {
|
||||||
|
selectSession(sessionId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleNewChatClick() {
|
||||||
|
if (isStreaming) {
|
||||||
|
pendingActionRef.current = async () => {
|
||||||
|
await stopCurrentStream();
|
||||||
|
startNewChat();
|
||||||
|
};
|
||||||
|
openInterruptModal(pendingActionRef.current);
|
||||||
|
} else {
|
||||||
|
startNewChat();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
isMobile,
|
isMobile,
|
||||||
isDrawerOpen,
|
isDrawerOpen,
|
||||||
|
|||||||
@@ -1,13 +1,6 @@
|
|||||||
"use client";
|
import type { ReactNode } from "react";
|
||||||
import { FeatureFlagPage } from "@/services/feature-flags/FeatureFlagPage";
|
|
||||||
import { Flag } from "@/services/feature-flags/use-get-flag";
|
|
||||||
import { type ReactNode } from "react";
|
|
||||||
import { CopilotShell } from "./components/CopilotShell/CopilotShell";
|
import { CopilotShell } from "./components/CopilotShell/CopilotShell";
|
||||||
|
|
||||||
export default function CopilotLayout({ children }: { children: ReactNode }) {
|
export default function CopilotLayout({ children }: { children: ReactNode }) {
|
||||||
return (
|
return <CopilotShell>{children}</CopilotShell>;
|
||||||
<FeatureFlagPage flag={Flag.CHAT} whenDisabled="/library">
|
|
||||||
<CopilotShell>{children}</CopilotShell>
|
|
||||||
</FeatureFlagPage>
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,8 +14,14 @@ export default function CopilotPage() {
|
|||||||
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
|
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
|
||||||
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
|
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
|
||||||
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
|
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
|
||||||
const { greetingName, quickActions, isLoading, hasSession, initialPrompt } =
|
const {
|
||||||
state;
|
greetingName,
|
||||||
|
quickActions,
|
||||||
|
isLoading,
|
||||||
|
hasSession,
|
||||||
|
initialPrompt,
|
||||||
|
isReady,
|
||||||
|
} = state;
|
||||||
const {
|
const {
|
||||||
handleQuickAction,
|
handleQuickAction,
|
||||||
startChatWithPrompt,
|
startChatWithPrompt,
|
||||||
@@ -23,6 +29,8 @@ export default function CopilotPage() {
|
|||||||
handleStreamingChange,
|
handleStreamingChange,
|
||||||
} = handlers;
|
} = handlers;
|
||||||
|
|
||||||
|
if (!isReady) return null;
|
||||||
|
|
||||||
if (hasSession) {
|
if (hasSession) {
|
||||||
return (
|
return (
|
||||||
<div className="flex h-full flex-col">
|
<div className="flex h-full flex-col">
|
||||||
|
|||||||
@@ -3,11 +3,18 @@ import {
|
|||||||
postV2CreateSession,
|
postV2CreateSession,
|
||||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||||
|
import { getHomepageRoute } from "@/lib/constants";
|
||||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||||
import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
|
import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
|
||||||
|
import {
|
||||||
|
Flag,
|
||||||
|
type FlagValues,
|
||||||
|
useGetFlag,
|
||||||
|
} from "@/services/feature-flags/use-get-flag";
|
||||||
import { SessionKey, sessionStorage } from "@/services/storage/session-storage";
|
import { SessionKey, sessionStorage } from "@/services/storage/session-storage";
|
||||||
import * as Sentry from "@sentry/nextjs";
|
import * as Sentry from "@sentry/nextjs";
|
||||||
import { useQueryClient } from "@tanstack/react-query";
|
import { useQueryClient } from "@tanstack/react-query";
|
||||||
|
import { useFlags } from "launchdarkly-react-client-sdk";
|
||||||
import { useRouter } from "next/navigation";
|
import { useRouter } from "next/navigation";
|
||||||
import { useEffect } from "react";
|
import { useEffect } from "react";
|
||||||
import { useCopilotStore } from "./copilot-page-store";
|
import { useCopilotStore } from "./copilot-page-store";
|
||||||
@@ -26,6 +33,22 @@ export function useCopilotPage() {
|
|||||||
const isCreating = useCopilotStore((s) => s.isCreatingSession);
|
const isCreating = useCopilotStore((s) => s.isCreatingSession);
|
||||||
const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession);
|
const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession);
|
||||||
|
|
||||||
|
// Complete VISIT_COPILOT onboarding step to grant $5 welcome bonus
|
||||||
|
useEffect(() => {
|
||||||
|
if (isLoggedIn) {
|
||||||
|
completeStep("VISIT_COPILOT");
|
||||||
|
}
|
||||||
|
}, [completeStep, isLoggedIn]);
|
||||||
|
|
||||||
|
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||||
|
const flags = useFlags<FlagValues>();
|
||||||
|
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||||
|
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
|
||||||
|
const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
|
||||||
|
const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
|
||||||
|
const isFlagReady =
|
||||||
|
!isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined;
|
||||||
|
|
||||||
const greetingName = getGreetingName(user);
|
const greetingName = getGreetingName(user);
|
||||||
const quickActions = getQuickActions();
|
const quickActions = getQuickActions();
|
||||||
|
|
||||||
@@ -35,8 +58,11 @@ export function useCopilotPage() {
|
|||||||
: undefined;
|
: undefined;
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (isLoggedIn) completeStep("VISIT_COPILOT");
|
if (!isFlagReady) return;
|
||||||
}, [completeStep, isLoggedIn]);
|
if (isChatEnabled === false) {
|
||||||
|
router.replace(homepageRoute);
|
||||||
|
}
|
||||||
|
}, [homepageRoute, isChatEnabled, isFlagReady, router]);
|
||||||
|
|
||||||
async function startChatWithPrompt(prompt: string) {
|
async function startChatWithPrompt(prompt: string) {
|
||||||
if (!prompt?.trim()) return;
|
if (!prompt?.trim()) return;
|
||||||
@@ -90,6 +116,7 @@ export function useCopilotPage() {
|
|||||||
isLoading: isUserLoading,
|
isLoading: isUserLoading,
|
||||||
hasSession,
|
hasSession,
|
||||||
initialPrompt,
|
initialPrompt,
|
||||||
|
isReady: isFlagReady && isChatEnabled !== false && isLoggedIn,
|
||||||
},
|
},
|
||||||
handlers: {
|
handlers: {
|
||||||
handleQuickAction,
|
handleQuickAction,
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||||
|
import { getHomepageRoute } from "@/lib/constants";
|
||||||
|
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||||
import { useSearchParams } from "next/navigation";
|
import { useSearchParams } from "next/navigation";
|
||||||
import { Suspense } from "react";
|
import { Suspense } from "react";
|
||||||
import { getErrorDetails } from "./helpers";
|
import { getErrorDetails } from "./helpers";
|
||||||
@@ -9,6 +11,8 @@ function ErrorPageContent() {
|
|||||||
const searchParams = useSearchParams();
|
const searchParams = useSearchParams();
|
||||||
const errorMessage = searchParams.get("message");
|
const errorMessage = searchParams.get("message");
|
||||||
const errorDetails = getErrorDetails(errorMessage);
|
const errorDetails = getErrorDetails(errorMessage);
|
||||||
|
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||||
|
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||||
|
|
||||||
function handleRetry() {
|
function handleRetry() {
|
||||||
// Auth-related errors should redirect to login
|
// Auth-related errors should redirect to login
|
||||||
@@ -26,7 +30,7 @@ function ErrorPageContent() {
|
|||||||
}, 2000);
|
}, 2000);
|
||||||
} else {
|
} else {
|
||||||
// For server/network errors, go to home
|
// For server/network errors, go to home
|
||||||
window.location.href = "/";
|
window.location.href = homepageRoute;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"use server";
|
"use server";
|
||||||
|
|
||||||
|
import { getHomepageRoute } from "@/lib/constants";
|
||||||
import BackendAPI from "@/lib/autogpt-server-api";
|
import BackendAPI from "@/lib/autogpt-server-api";
|
||||||
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
||||||
import { loginFormSchema } from "@/types/auth";
|
import { loginFormSchema } from "@/types/auth";
|
||||||
@@ -37,8 +38,10 @@ export async function login(email: string, password: string) {
|
|||||||
await api.createUser();
|
await api.createUser();
|
||||||
|
|
||||||
// Get onboarding status from backend (includes chat flag evaluated for this user)
|
// Get onboarding status from backend (includes chat flag evaluated for this user)
|
||||||
const { shouldShowOnboarding } = await getOnboardingStatus();
|
const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus();
|
||||||
const next = shouldShowOnboarding ? "/onboarding" : "/";
|
const next = shouldShowOnboarding
|
||||||
|
? "/onboarding"
|
||||||
|
: getHomepageRoute(isChatEnabled);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
success: true,
|
success: true,
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||||
|
import { getHomepageRoute } from "@/lib/constants";
|
||||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||||
import { environment } from "@/services/environment";
|
import { environment } from "@/services/environment";
|
||||||
|
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||||
import { loginFormSchema, LoginProvider } from "@/types/auth";
|
import { loginFormSchema, LoginProvider } from "@/types/auth";
|
||||||
import { zodResolver } from "@hookform/resolvers/zod";
|
import { zodResolver } from "@hookform/resolvers/zod";
|
||||||
import { useRouter, useSearchParams } from "next/navigation";
|
import { useRouter, useSearchParams } from "next/navigation";
|
||||||
@@ -20,15 +22,17 @@ export function useLoginPage() {
|
|||||||
const [isGoogleLoading, setIsGoogleLoading] = useState(false);
|
const [isGoogleLoading, setIsGoogleLoading] = useState(false);
|
||||||
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
|
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
|
||||||
const isCloudEnv = environment.isCloud();
|
const isCloudEnv = environment.isCloud();
|
||||||
|
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||||
|
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||||
|
|
||||||
// Get redirect destination from 'next' query parameter
|
// Get redirect destination from 'next' query parameter
|
||||||
const nextUrl = searchParams.get("next");
|
const nextUrl = searchParams.get("next");
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (isLoggedIn && !isLoggingIn) {
|
if (isLoggedIn && !isLoggingIn) {
|
||||||
router.push(nextUrl || "/");
|
router.push(nextUrl || homepageRoute);
|
||||||
}
|
}
|
||||||
}, [isLoggedIn, isLoggingIn, nextUrl, router]);
|
}, [homepageRoute, isLoggedIn, isLoggingIn, nextUrl, router]);
|
||||||
|
|
||||||
const form = useForm<z.infer<typeof loginFormSchema>>({
|
const form = useForm<z.infer<typeof loginFormSchema>>({
|
||||||
resolver: zodResolver(loginFormSchema),
|
resolver: zodResolver(loginFormSchema),
|
||||||
@@ -94,7 +98,7 @@ export function useLoginPage() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Prefer URL's next parameter, then use backend-determined route
|
// Prefer URL's next parameter, then use backend-determined route
|
||||||
router.replace(nextUrl || result.next || "/");
|
router.replace(nextUrl || result.next || homepageRoute);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
toast({
|
toast({
|
||||||
title:
|
title:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"use server";
|
"use server";
|
||||||
|
|
||||||
|
import { getHomepageRoute } from "@/lib/constants";
|
||||||
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
||||||
import { signupFormSchema } from "@/types/auth";
|
import { signupFormSchema } from "@/types/auth";
|
||||||
import * as Sentry from "@sentry/nextjs";
|
import * as Sentry from "@sentry/nextjs";
|
||||||
@@ -58,8 +59,10 @@ export async function signup(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get onboarding status from backend (includes chat flag evaluated for this user)
|
// Get onboarding status from backend (includes chat flag evaluated for this user)
|
||||||
const { shouldShowOnboarding } = await getOnboardingStatus();
|
const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus();
|
||||||
const next = shouldShowOnboarding ? "/onboarding" : "/";
|
const next = shouldShowOnboarding
|
||||||
|
? "/onboarding"
|
||||||
|
: getHomepageRoute(isChatEnabled);
|
||||||
|
|
||||||
return { success: true, next };
|
return { success: true, next };
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||||
|
import { getHomepageRoute } from "@/lib/constants";
|
||||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||||
import { environment } from "@/services/environment";
|
import { environment } from "@/services/environment";
|
||||||
|
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||||
import { LoginProvider, signupFormSchema } from "@/types/auth";
|
import { LoginProvider, signupFormSchema } from "@/types/auth";
|
||||||
import { zodResolver } from "@hookform/resolvers/zod";
|
import { zodResolver } from "@hookform/resolvers/zod";
|
||||||
import { useRouter, useSearchParams } from "next/navigation";
|
import { useRouter, useSearchParams } from "next/navigation";
|
||||||
@@ -20,15 +22,17 @@ export function useSignupPage() {
|
|||||||
const [isGoogleLoading, setIsGoogleLoading] = useState(false);
|
const [isGoogleLoading, setIsGoogleLoading] = useState(false);
|
||||||
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
|
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
|
||||||
const isCloudEnv = environment.isCloud();
|
const isCloudEnv = environment.isCloud();
|
||||||
|
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||||
|
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||||
|
|
||||||
// Get redirect destination from 'next' query parameter
|
// Get redirect destination from 'next' query parameter
|
||||||
const nextUrl = searchParams.get("next");
|
const nextUrl = searchParams.get("next");
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (isLoggedIn && !isSigningUp) {
|
if (isLoggedIn && !isSigningUp) {
|
||||||
router.push(nextUrl || "/");
|
router.push(nextUrl || homepageRoute);
|
||||||
}
|
}
|
||||||
}, [isLoggedIn, isSigningUp, nextUrl, router]);
|
}, [homepageRoute, isLoggedIn, isSigningUp, nextUrl, router]);
|
||||||
|
|
||||||
const form = useForm<z.infer<typeof signupFormSchema>>({
|
const form = useForm<z.infer<typeof signupFormSchema>>({
|
||||||
resolver: zodResolver(signupFormSchema),
|
resolver: zodResolver(signupFormSchema),
|
||||||
@@ -129,7 +133,7 @@ export function useSignupPage() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Prefer the URL's next parameter, then result.next (for onboarding), then default
|
// Prefer the URL's next parameter, then result.next (for onboarding), then default
|
||||||
const redirectTo = nextUrl || result.next || "/";
|
const redirectTo = nextUrl || result.next || homepageRoute;
|
||||||
router.replace(redirectTo);
|
router.replace(redirectTo);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
setIsLoading(false);
|
setIsLoading(false);
|
||||||
|
|||||||
@@ -1,81 +0,0 @@
|
|||||||
import { environment } from "@/services/environment";
|
|
||||||
import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
|
|
||||||
import { NextRequest } from "next/server";
|
|
||||||
|
|
||||||
/**
|
|
||||||
* SSE Proxy for task stream reconnection.
|
|
||||||
*
|
|
||||||
* This endpoint allows clients to reconnect to an ongoing or recently completed
|
|
||||||
* background task's stream. It replays missed messages from Redis Streams and
|
|
||||||
* subscribes to live updates if the task is still running.
|
|
||||||
*
|
|
||||||
* Client contract:
|
|
||||||
* 1. When receiving an operation_started event, store the task_id
|
|
||||||
* 2. To reconnect: GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
|
|
||||||
* 3. Messages are replayed from the last_message_id position
|
|
||||||
* 4. Stream ends when "finish" event is received
|
|
||||||
*/
|
|
||||||
export async function GET(
|
|
||||||
request: NextRequest,
|
|
||||||
{ params }: { params: Promise<{ taskId: string }> },
|
|
||||||
) {
|
|
||||||
const { taskId } = await params;
|
|
||||||
const searchParams = request.nextUrl.searchParams;
|
|
||||||
const lastMessageId = searchParams.get("last_message_id") || "0-0";
|
|
||||||
|
|
||||||
try {
|
|
||||||
// Get auth token from server-side session
|
|
||||||
const token = await getServerAuthToken();
|
|
||||||
|
|
||||||
// Build backend URL
|
|
||||||
const backendUrl = environment.getAGPTServerBaseUrl();
|
|
||||||
const streamUrl = new URL(`/api/chat/tasks/${taskId}/stream`, backendUrl);
|
|
||||||
streamUrl.searchParams.set("last_message_id", lastMessageId);
|
|
||||||
|
|
||||||
// Forward request to backend with auth header
|
|
||||||
const headers: Record<string, string> = {
|
|
||||||
Accept: "text/event-stream",
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
Connection: "keep-alive",
|
|
||||||
};
|
|
||||||
|
|
||||||
if (token) {
|
|
||||||
headers["Authorization"] = `Bearer ${token}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
const response = await fetch(streamUrl.toString(), {
|
|
||||||
method: "GET",
|
|
||||||
headers,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!response.ok) {
|
|
||||||
const error = await response.text();
|
|
||||||
return new Response(error, {
|
|
||||||
status: response.status,
|
|
||||||
headers: { "Content-Type": "application/json" },
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the SSE stream directly
|
|
||||||
return new Response(response.body, {
|
|
||||||
headers: {
|
|
||||||
"Content-Type": "text/event-stream",
|
|
||||||
"Cache-Control": "no-cache, no-transform",
|
|
||||||
Connection: "keep-alive",
|
|
||||||
"X-Accel-Buffering": "no",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
} catch (error) {
|
|
||||||
console.error("Task stream proxy error:", error);
|
|
||||||
return new Response(
|
|
||||||
JSON.stringify({
|
|
||||||
error: "Failed to connect to task stream",
|
|
||||||
detail: error instanceof Error ? error.message : String(error),
|
|
||||||
}),
|
|
||||||
{
|
|
||||||
status: 500,
|
|
||||||
headers: { "Content-Type": "application/json" },
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -181,5 +181,6 @@ export async function getOnboardingStatus() {
|
|||||||
const isCompleted = onboarding.completedSteps.includes("CONGRATS");
|
const isCompleted = onboarding.completedSteps.includes("CONGRATS");
|
||||||
return {
|
return {
|
||||||
shouldShowOnboarding: status.is_onboarding_enabled && !isCompleted,
|
shouldShowOnboarding: status.is_onboarding_enabled && !isCompleted,
|
||||||
|
isChatEnabled: status.is_chat_enabled,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -917,28 +917,6 @@
|
|||||||
"security": [{ "HTTPBearerJWT": [] }]
|
"security": [{ "HTTPBearerJWT": [] }]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"/api/chat/config/ttl": {
|
|
||||||
"get": {
|
|
||||||
"tags": ["v2", "chat", "chat"],
|
|
||||||
"summary": "Get Ttl Config",
|
|
||||||
"description": "Get the stream TTL configuration.\n\nReturns the Time-To-Live settings for chat streams, which determines\nhow long clients can reconnect to an active stream.\n\nReturns:\n dict: TTL configuration with seconds and milliseconds values.",
|
|
||||||
"operationId": "getV2GetTtlConfig",
|
|
||||||
"responses": {
|
|
||||||
"200": {
|
|
||||||
"description": "Successful Response",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": {
|
|
||||||
"additionalProperties": true,
|
|
||||||
"type": "object",
|
|
||||||
"title": "Response Getv2Getttlconfig"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"/api/chat/health": {
|
"/api/chat/health": {
|
||||||
"get": {
|
"get": {
|
||||||
"tags": ["v2", "chat", "chat"],
|
"tags": ["v2", "chat", "chat"],
|
||||||
@@ -961,63 +939,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"/api/chat/operations/{operation_id}/complete": {
|
|
||||||
"post": {
|
|
||||||
"tags": ["v2", "chat", "chat"],
|
|
||||||
"summary": "Complete Operation",
|
|
||||||
"description": "External completion webhook for long-running operations.\n\nCalled by Agent Generator (or other services) when an operation completes.\nThis triggers the stream registry to publish completion and continue LLM generation.\n\nArgs:\n operation_id: The operation ID to complete.\n request: Completion payload with success status and result/error.\n x_api_key: Internal API key for authentication.\n\nReturns:\n dict: Status of the completion.\n\nRaises:\n HTTPException: If API key is invalid or operation not found.",
|
|
||||||
"operationId": "postV2CompleteOperation",
|
|
||||||
"parameters": [
|
|
||||||
{
|
|
||||||
"name": "operation_id",
|
|
||||||
"in": "path",
|
|
||||||
"required": true,
|
|
||||||
"schema": { "type": "string", "title": "Operation Id" }
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "x-api-key",
|
|
||||||
"in": "header",
|
|
||||||
"required": false,
|
|
||||||
"schema": {
|
|
||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
|
||||||
"title": "X-Api-Key"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"requestBody": {
|
|
||||||
"required": true,
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": {
|
|
||||||
"$ref": "#/components/schemas/OperationCompleteRequest"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"responses": {
|
|
||||||
"200": {
|
|
||||||
"description": "Successful Response",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": {
|
|
||||||
"type": "object",
|
|
||||||
"additionalProperties": true,
|
|
||||||
"title": "Response Postv2Completeoperation"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"422": {
|
|
||||||
"description": "Validation Error",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"/api/chat/sessions": {
|
"/api/chat/sessions": {
|
||||||
"get": {
|
"get": {
|
||||||
"tags": ["v2", "chat", "chat"],
|
"tags": ["v2", "chat", "chat"],
|
||||||
@@ -1101,7 +1022,7 @@
|
|||||||
"get": {
|
"get": {
|
||||||
"tags": ["v2", "chat", "chat"],
|
"tags": ["v2", "chat", "chat"],
|
||||||
"summary": "Get Session",
|
"summary": "Get Session",
|
||||||
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\nIf there's an active stream for this session, returns the task_id for reconnection.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, including active_stream info if applicable.",
|
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, or None if not found.",
|
||||||
"operationId": "getV2GetSession",
|
"operationId": "getV2GetSession",
|
||||||
"security": [{ "HTTPBearerJWT": [] }],
|
"security": [{ "HTTPBearerJWT": [] }],
|
||||||
"parameters": [
|
"parameters": [
|
||||||
@@ -1236,7 +1157,7 @@
|
|||||||
"post": {
|
"post": {
|
||||||
"tags": ["v2", "chat", "chat"],
|
"tags": ["v2", "chat", "chat"],
|
||||||
"summary": "Stream Chat Post",
|
"summary": "Stream Chat Post",
|
||||||
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to Redis for reconnection support. If the client disconnects,\nthey can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks. First chunk is a \"start\" event\n containing the task_id for reconnection.",
|
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
|
||||||
"operationId": "postV2StreamChatPost",
|
"operationId": "postV2StreamChatPost",
|
||||||
"security": [{ "HTTPBearerJWT": [] }],
|
"security": [{ "HTTPBearerJWT": [] }],
|
||||||
"parameters": [
|
"parameters": [
|
||||||
@@ -1274,94 +1195,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"/api/chat/tasks/{task_id}": {
|
|
||||||
"get": {
|
|
||||||
"tags": ["v2", "chat", "chat"],
|
|
||||||
"summary": "Get Task Status",
|
|
||||||
"description": "Get the status of a long-running task.\n\nArgs:\n task_id: The task ID to check.\n user_id: Authenticated user ID for ownership validation.\n\nReturns:\n dict: Task status including task_id, status, tool_name, and operation_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.",
|
|
||||||
"operationId": "getV2GetTaskStatus",
|
|
||||||
"security": [{ "HTTPBearerJWT": [] }],
|
|
||||||
"parameters": [
|
|
||||||
{
|
|
||||||
"name": "task_id",
|
|
||||||
"in": "path",
|
|
||||||
"required": true,
|
|
||||||
"schema": { "type": "string", "title": "Task Id" }
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"responses": {
|
|
||||||
"200": {
|
|
||||||
"description": "Successful Response",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": {
|
|
||||||
"type": "object",
|
|
||||||
"additionalProperties": true,
|
|
||||||
"title": "Response Getv2Gettaskstatus"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"401": {
|
|
||||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
|
||||||
},
|
|
||||||
"422": {
|
|
||||||
"description": "Validation Error",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"/api/chat/tasks/{task_id}/stream": {
|
|
||||||
"get": {
|
|
||||||
"tags": ["v2", "chat", "chat"],
|
|
||||||
"summary": "Stream Task",
|
|
||||||
"description": "Reconnect to a long-running task's SSE stream.\n\nWhen a long-running operation (like agent generation) starts, the client\nreceives a task_id. If the connection drops, the client can reconnect\nusing this endpoint to resume receiving updates.\n\nArgs:\n task_id: The task ID from the operation_started response.\n user_id: Authenticated user ID for ownership validation.\n last_message_id: Last Redis Stream message ID received (\"0-0\" for full replay).\n\nReturns:\n StreamingResponse: SSE-formatted response chunks starting after last_message_id.\n\nRaises:\n HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.",
|
|
||||||
"operationId": "getV2StreamTask",
|
|
||||||
"security": [{ "HTTPBearerJWT": [] }],
|
|
||||||
"parameters": [
|
|
||||||
{
|
|
||||||
"name": "task_id",
|
|
||||||
"in": "path",
|
|
||||||
"required": true,
|
|
||||||
"schema": { "type": "string", "title": "Task Id" }
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "last_message_id",
|
|
||||||
"in": "query",
|
|
||||||
"required": false,
|
|
||||||
"schema": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
|
|
||||||
"default": "0-0",
|
|
||||||
"title": "Last Message Id"
|
|
||||||
},
|
|
||||||
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay."
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"responses": {
|
|
||||||
"200": {
|
|
||||||
"description": "Successful Response",
|
|
||||||
"content": { "application/json": { "schema": {} } }
|
|
||||||
},
|
|
||||||
"401": {
|
|
||||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
|
||||||
},
|
|
||||||
"422": {
|
|
||||||
"description": "Validation Error",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"/api/credits": {
|
"/api/credits": {
|
||||||
"get": {
|
"get": {
|
||||||
"tags": ["v1", "credits"],
|
"tags": ["v1", "credits"],
|
||||||
@@ -6335,18 +6168,6 @@
|
|||||||
"title": "AccuracyTrendsResponse",
|
"title": "AccuracyTrendsResponse",
|
||||||
"description": "Response model for accuracy trends and alerts."
|
"description": "Response model for accuracy trends and alerts."
|
||||||
},
|
},
|
||||||
"ActiveStreamInfo": {
|
|
||||||
"properties": {
|
|
||||||
"task_id": { "type": "string", "title": "Task Id" },
|
|
||||||
"last_message_id": { "type": "string", "title": "Last Message Id" },
|
|
||||||
"operation_id": { "type": "string", "title": "Operation Id" },
|
|
||||||
"tool_name": { "type": "string", "title": "Tool Name" }
|
|
||||||
},
|
|
||||||
"type": "object",
|
|
||||||
"required": ["task_id", "last_message_id", "operation_id", "tool_name"],
|
|
||||||
"title": "ActiveStreamInfo",
|
|
||||||
"description": "Information about an active stream for reconnection."
|
|
||||||
},
|
|
||||||
"AddUserCreditsResponse": {
|
"AddUserCreditsResponse": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"new_balance": { "type": "integer", "title": "New Balance" },
|
"new_balance": { "type": "integer", "title": "New Balance" },
|
||||||
@@ -9002,27 +8823,6 @@
|
|||||||
],
|
],
|
||||||
"title": "OnboardingStep"
|
"title": "OnboardingStep"
|
||||||
},
|
},
|
||||||
"OperationCompleteRequest": {
|
|
||||||
"properties": {
|
|
||||||
"success": { "type": "boolean", "title": "Success" },
|
|
||||||
"result": {
|
|
||||||
"anyOf": [
|
|
||||||
{ "additionalProperties": true, "type": "object" },
|
|
||||||
{ "type": "string" },
|
|
||||||
{ "type": "null" }
|
|
||||||
],
|
|
||||||
"title": "Result"
|
|
||||||
},
|
|
||||||
"error": {
|
|
||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
|
||||||
"title": "Error"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"type": "object",
|
|
||||||
"required": ["success"],
|
|
||||||
"title": "OperationCompleteRequest",
|
|
||||||
"description": "Request model for external completion webhook."
|
|
||||||
},
|
|
||||||
"Pagination": {
|
"Pagination": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"total_items": {
|
"total_items": {
|
||||||
@@ -9878,12 +9678,6 @@
|
|||||||
"items": { "additionalProperties": true, "type": "object" },
|
"items": { "additionalProperties": true, "type": "object" },
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"title": "Messages"
|
"title": "Messages"
|
||||||
},
|
|
||||||
"active_stream": {
|
|
||||||
"anyOf": [
|
|
||||||
{ "$ref": "#/components/schemas/ActiveStreamInfo" },
|
|
||||||
{ "type": "null" }
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
|||||||
@@ -1,15 +1,27 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
import { getHomepageRoute } from "@/lib/constants";
|
||||||
|
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||||
import { useRouter } from "next/navigation";
|
import { useRouter } from "next/navigation";
|
||||||
import { useEffect } from "react";
|
import { useEffect } from "react";
|
||||||
|
|
||||||
export default function Page() {
|
export default function Page() {
|
||||||
|
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
|
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||||
|
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
|
||||||
|
const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
|
||||||
|
const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
|
||||||
|
const isFlagReady =
|
||||||
|
!isLaunchDarklyConfigured || typeof isChatEnabled === "boolean";
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(
|
||||||
router.replace("/copilot");
|
function redirectToHomepage() {
|
||||||
}, [router]);
|
if (!isFlagReady) return;
|
||||||
|
router.replace(homepageRoute);
|
||||||
|
},
|
||||||
|
[homepageRoute, isFlagReady, router],
|
||||||
|
);
|
||||||
|
|
||||||
return <LoadingSpinner size="large" cover />;
|
return null;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId";
|
import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId";
|
||||||
|
import { useCopilotStore } from "@/app/(platform)/copilot/copilot-page-store";
|
||||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||||
import { Text } from "@/components/atoms/Text/Text";
|
import { Text } from "@/components/atoms/Text/Text";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
@@ -24,8 +25,8 @@ export function Chat({
|
|||||||
}: ChatProps) {
|
}: ChatProps) {
|
||||||
const { urlSessionId } = useCopilotSessionId();
|
const { urlSessionId } = useCopilotSessionId();
|
||||||
const hasHandledNotFoundRef = useRef(false);
|
const hasHandledNotFoundRef = useRef(false);
|
||||||
|
const isSwitchingSession = useCopilotStore((s) => s.isSwitchingSession);
|
||||||
const {
|
const {
|
||||||
session,
|
|
||||||
messages,
|
messages,
|
||||||
isLoading,
|
isLoading,
|
||||||
isCreating,
|
isCreating,
|
||||||
@@ -37,18 +38,6 @@ export function Chat({
|
|||||||
startPollingForOperation,
|
startPollingForOperation,
|
||||||
} = useChat({ urlSessionId });
|
} = useChat({ urlSessionId });
|
||||||
|
|
||||||
// Extract active stream info for reconnection
|
|
||||||
const activeStream = (
|
|
||||||
session as {
|
|
||||||
active_stream?: {
|
|
||||||
task_id: string;
|
|
||||||
last_message_id: string;
|
|
||||||
operation_id: string;
|
|
||||||
tool_name: string;
|
|
||||||
};
|
|
||||||
}
|
|
||||||
)?.active_stream;
|
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!onSessionNotFound) return;
|
if (!onSessionNotFound) return;
|
||||||
if (!urlSessionId) return;
|
if (!urlSessionId) return;
|
||||||
@@ -64,7 +53,8 @@ export function Chat({
|
|||||||
isCreating,
|
isCreating,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
const shouldShowLoader = showLoader && (isLoading || isCreating);
|
const shouldShowLoader =
|
||||||
|
(showLoader && (isLoading || isCreating)) || isSwitchingSession;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={cn("flex h-full flex-col", className)}>
|
<div className={cn("flex h-full flex-col", className)}>
|
||||||
@@ -76,19 +66,21 @@ export function Chat({
|
|||||||
<div className="flex flex-col items-center gap-3">
|
<div className="flex flex-col items-center gap-3">
|
||||||
<LoadingSpinner size="large" className="text-neutral-400" />
|
<LoadingSpinner size="large" className="text-neutral-400" />
|
||||||
<Text variant="body" className="text-zinc-500">
|
<Text variant="body" className="text-zinc-500">
|
||||||
Loading your chat...
|
{isSwitchingSession
|
||||||
|
? "Switching chat..."
|
||||||
|
: "Loading your chat..."}
|
||||||
</Text>
|
</Text>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{/* Error State */}
|
{/* Error State */}
|
||||||
{error && !isLoading && (
|
{error && !isLoading && !isSwitchingSession && (
|
||||||
<ChatErrorState error={error} onRetry={createSession} />
|
<ChatErrorState error={error} onRetry={createSession} />
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{/* Session Content */}
|
{/* Session Content */}
|
||||||
{sessionId && !isLoading && !error && (
|
{sessionId && !isLoading && !error && !isSwitchingSession && (
|
||||||
<ChatContainer
|
<ChatContainer
|
||||||
sessionId={sessionId}
|
sessionId={sessionId}
|
||||||
initialMessages={messages}
|
initialMessages={messages}
|
||||||
@@ -96,16 +88,6 @@ export function Chat({
|
|||||||
className="flex-1"
|
className="flex-1"
|
||||||
onStreamingChange={onStreamingChange}
|
onStreamingChange={onStreamingChange}
|
||||||
onOperationStarted={startPollingForOperation}
|
onOperationStarted={startPollingForOperation}
|
||||||
activeStream={
|
|
||||||
activeStream
|
|
||||||
? {
|
|
||||||
taskId: activeStream.task_id,
|
|
||||||
lastMessageId: activeStream.last_message_id,
|
|
||||||
operationId: activeStream.operation_id,
|
|
||||||
toolName: activeStream.tool_name,
|
|
||||||
}
|
|
||||||
: undefined
|
|
||||||
}
|
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
</main>
|
</main>
|
||||||
|
|||||||
@@ -1,159 +0,0 @@
|
|||||||
# SSE Reconnection Contract for Long-Running Operations
|
|
||||||
|
|
||||||
This document describes the client-side contract for handling SSE (Server-Sent Events) disconnections and reconnecting to long-running background tasks.
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
When a user triggers a long-running operation (like agent generation), the backend:
|
|
||||||
|
|
||||||
1. Spawns a background task that survives SSE disconnections
|
|
||||||
2. Returns an `operation_started` response with a `task_id`
|
|
||||||
3. Stores stream messages in Redis Streams for replay
|
|
||||||
|
|
||||||
Clients can reconnect to the task stream at any time to receive missed messages.
|
|
||||||
|
|
||||||
## Client-Side Flow
|
|
||||||
|
|
||||||
### 1. Receiving Operation Started
|
|
||||||
|
|
||||||
When you receive an `operation_started` tool response:
|
|
||||||
|
|
||||||
```typescript
|
|
||||||
// The response includes a task_id for reconnection
|
|
||||||
{
|
|
||||||
type: "operation_started",
|
|
||||||
tool_name: "generate_agent",
|
|
||||||
operation_id: "uuid-...",
|
|
||||||
task_id: "task-uuid-...", // <-- Store this for reconnection
|
|
||||||
message: "Operation started. You can close this tab."
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. Storing Task Info
|
|
||||||
|
|
||||||
Use the chat store to track the active task:
|
|
||||||
|
|
||||||
```typescript
|
|
||||||
import { useChatStore } from "./chat-store";
|
|
||||||
|
|
||||||
// When operation_started is received:
|
|
||||||
useChatStore.getState().setActiveTask(sessionId, {
|
|
||||||
taskId: response.task_id,
|
|
||||||
operationId: response.operation_id,
|
|
||||||
toolName: response.tool_name,
|
|
||||||
lastMessageId: "0",
|
|
||||||
});
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Reconnecting to a Task
|
|
||||||
|
|
||||||
To reconnect (e.g., after page refresh or tab reopen):
|
|
||||||
|
|
||||||
```typescript
|
|
||||||
const { reconnectToTask, getActiveTask } = useChatStore.getState();
|
|
||||||
|
|
||||||
// Check if there's an active task for this session
|
|
||||||
const activeTask = getActiveTask(sessionId);
|
|
||||||
|
|
||||||
if (activeTask) {
|
|
||||||
// Reconnect to the task stream
|
|
||||||
await reconnectToTask(
|
|
||||||
sessionId,
|
|
||||||
activeTask.taskId,
|
|
||||||
activeTask.lastMessageId, // Resume from last position
|
|
||||||
(chunk) => {
|
|
||||||
// Handle incoming chunks
|
|
||||||
console.log("Received chunk:", chunk);
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. Tracking Message Position
|
|
||||||
|
|
||||||
To enable precise replay, update the last message ID as chunks arrive:
|
|
||||||
|
|
||||||
```typescript
|
|
||||||
const { updateTaskLastMessageId } = useChatStore.getState();
|
|
||||||
|
|
||||||
function handleChunk(chunk: StreamChunk) {
|
|
||||||
// If chunk has an index/id, track it
|
|
||||||
if (chunk.idx !== undefined) {
|
|
||||||
updateTaskLastMessageId(sessionId, String(chunk.idx));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## API Endpoints
|
|
||||||
|
|
||||||
### Task Stream Reconnection
|
|
||||||
|
|
||||||
```
|
|
||||||
GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
|
|
||||||
```
|
|
||||||
|
|
||||||
- `taskId`: The task ID from `operation_started`
|
|
||||||
- `last_message_id`: Last received message index (default: "0" for full replay)
|
|
||||||
|
|
||||||
Returns: an SSE stream that first replays the missed messages and then continues with live updates.
|
|
||||||
|
|
||||||
## Chunk Types
|
|
||||||
|
|
||||||
The reconnected stream follows the same Vercel AI SDK protocol:
|
|
||||||
|
|
||||||
| Type | Description |
|
|
||||||
| ----------------------- | ----------------------- |
|
|
||||||
| `start` | Message lifecycle start |
|
|
||||||
| `text-delta` | Streaming text content |
|
|
||||||
| `text-end` | Text block completed |
|
|
||||||
| `tool-output-available` | Tool result available |
|
|
||||||
| `finish` | Stream completed |
|
|
||||||
| `error` | Error occurred |
|
|
||||||
|
|
||||||
## Error Handling
|
|
||||||
|
|
||||||
If reconnection fails:
|
|
||||||
|
|
||||||
1. Check whether the task still exists (it may have expired; the default TTL is 1 hour)
|
|
||||||
2. Fall back to polling the session for final state
|
|
||||||
3. Show appropriate UI message to user
|
|
||||||
|
|
||||||
## Persistence Considerations
|
|
||||||
|
|
||||||
For robust reconnection across browser restarts:
|
|
||||||
|
|
||||||
```typescript
|
|
||||||
// Store in localStorage (sessionStorage is cleared when the tab or browser closes, so it cannot survive a restart)
|
|
||||||
const ACTIVE_TASKS_KEY = "chat_active_tasks";
|
|
||||||
|
|
||||||
function persistActiveTask(sessionId: string, task: ActiveTaskInfo) {
|
|
||||||
const tasks = JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
|
|
||||||
tasks[sessionId] = task;
|
|
||||||
localStorage.setItem(ACTIVE_TASKS_KEY, JSON.stringify(tasks));
|
|
||||||
}
|
|
||||||
|
|
||||||
function loadPersistedTasks(): Record<string, ActiveTaskInfo> {
|
|
||||||
return JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Backend Configuration
|
|
||||||
|
|
||||||
The following backend settings affect reconnection behavior:
|
|
||||||
|
|
||||||
| Setting | Default | Description |
|
|
||||||
| ------------------- | ------- | ---------------------------------- |
|
|
||||||
| `stream_ttl` | 3600s | How long streams are kept in Redis |
|
|
||||||
| `stream_max_length` | 1000 | Max messages per stream |
|
|
||||||
|
|
||||||
## Testing
|
|
||||||
|
|
||||||
To test reconnection locally:
|
|
||||||
|
|
||||||
1. Start a long-running operation (e.g., agent generation)
|
|
||||||
2. Note the `task_id` from the `operation_started` response
|
|
||||||
3. Close the browser tab
|
|
||||||
4. Reopen the tab and call `reconnectToTask` with the saved `task_id`
|
|
||||||
5. Verify that missed messages are replayed
|
|
||||||
|
|
||||||
See the main README for full local development setup.
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
/**
 * Shared constants for the chat system.
 *
 * Keeps the string keys, stream IDs and timeouts used by the chat
 * components in a single module.
 */

/** Number of milliseconds in one minute; unit for the TTLs below. */
const MS_PER_MINUTE = 60 * 1000;

/** localStorage key under which active tasks are stored. */
export const STORAGE_KEY_ACTIVE_TASKS = "chat_active_tasks";

/** Initial message ID. */
export const INITIAL_MESSAGE_ID = "0";

/** Initial Redis Stream ID. */
export const INITIAL_STREAM_ID = "0-0";

/** TTL for completed streams, in milliseconds (5 minutes). */
export const COMPLETED_STREAM_TTL_MS = 5 * MS_PER_MINUTE;

/** TTL for active tasks, in milliseconds (1 hour). */
export const ACTIVE_TASK_TTL_MS = 60 * MS_PER_MINUTE;
|
|
||||||
@@ -1,12 +1,6 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { create } from "zustand";
|
import { create } from "zustand";
|
||||||
import {
|
|
||||||
ACTIVE_TASK_TTL_MS,
|
|
||||||
COMPLETED_STREAM_TTL_MS,
|
|
||||||
INITIAL_STREAM_ID,
|
|
||||||
STORAGE_KEY_ACTIVE_TASKS,
|
|
||||||
} from "./chat-constants";
|
|
||||||
import type {
|
import type {
|
||||||
ActiveStream,
|
ActiveStream,
|
||||||
StreamChunk,
|
StreamChunk,
|
||||||
@@ -14,59 +8,15 @@ import type {
|
|||||||
StreamResult,
|
StreamResult,
|
||||||
StreamStatus,
|
StreamStatus,
|
||||||
} from "./chat-types";
|
} from "./chat-types";
|
||||||
import { executeStream, executeTaskReconnect } from "./stream-executor";
|
import { executeStream } from "./stream-executor";
|
||||||
|
|
||||||
export interface ActiveTaskInfo {
|
const COMPLETED_STREAM_TTL = 5 * 60 * 1000; // 5 minutes
|
||||||
taskId: string;
|
|
||||||
sessionId: string;
|
|
||||||
operationId: string;
|
|
||||||
toolName: string;
|
|
||||||
lastMessageId: string;
|
|
||||||
startedAt: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
 * Read the active-task table back from localStorage.
 *
 * Returns an empty map when `window` is undefined, when nothing is stored
 * under STORAGE_KEY_ACTIVE_TASKS, or when reading/parsing throws. Entries
 * whose `startedAt` is ACTIVE_TASK_TTL_MS or more in the past are left out.
 */
function loadPersistedTasks(): Map<string, ActiveTaskInfo> {
  if (typeof window === "undefined") {
    return new Map();
  }

  try {
    const serialized = localStorage.getItem(STORAGE_KEY_ACTIVE_TASKS);
    if (!serialized) {
      return new Map();
    }

    const storedBySession = JSON.parse(serialized) as Record<
      string,
      ActiveTaskInfo
    >;
    const loadedAt = Date.now();

    // Keep only the entries that are still inside their TTL window.
    const unexpired = Object.entries(storedBySession).filter(
      ([, task]) => loadedAt - task.startedAt < ACTIVE_TASK_TTL_MS,
    );
    return new Map<string, ActiveTaskInfo>(unexpired);
  } catch {
    // Any storage or parse failure yields an empty table.
    return new Map();
  }
}
|
|
||||||
|
|
||||||
/**
 * Write the active-task table to localStorage as a JSON object keyed by
 * session ID. Does nothing when `window` is undefined; storage errors are
 * swallowed.
 */
function persistTasks(tasks: Map<string, ActiveTaskInfo>): void {
  const inBrowser = typeof window !== "undefined";
  if (!inBrowser) return;

  try {
    const snapshot: Record<string, ActiveTaskInfo> = {};
    tasks.forEach((task, sessionId) => {
      snapshot[sessionId] = task;
    });
    localStorage.setItem(STORAGE_KEY_ACTIVE_TASKS, JSON.stringify(snapshot));
  } catch {
    // Ignore storage errors
  }
}
|
|
||||||
|
|
||||||
interface ChatStoreState {
|
interface ChatStoreState {
|
||||||
activeStreams: Map<string, ActiveStream>;
|
activeStreams: Map<string, ActiveStream>;
|
||||||
completedStreams: Map<string, StreamResult>;
|
completedStreams: Map<string, StreamResult>;
|
||||||
activeSessions: Set<string>;
|
activeSessions: Set<string>;
|
||||||
streamCompleteCallbacks: Set<StreamCompleteCallback>;
|
streamCompleteCallbacks: Set<StreamCompleteCallback>;
|
||||||
/** Active tasks for SSE reconnection - keyed by sessionId */
|
|
||||||
activeTasks: Map<string, ActiveTaskInfo>;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
interface ChatStoreActions {
|
interface ChatStoreActions {
|
||||||
@@ -91,24 +41,6 @@ interface ChatStoreActions {
|
|||||||
unregisterActiveSession: (sessionId: string) => void;
|
unregisterActiveSession: (sessionId: string) => void;
|
||||||
isSessionActive: (sessionId: string) => boolean;
|
isSessionActive: (sessionId: string) => boolean;
|
||||||
onStreamComplete: (callback: StreamCompleteCallback) => () => void;
|
onStreamComplete: (callback: StreamCompleteCallback) => () => void;
|
||||||
/** Track active task for SSE reconnection */
|
|
||||||
setActiveTask: (
|
|
||||||
sessionId: string,
|
|
||||||
taskInfo: Omit<ActiveTaskInfo, "sessionId" | "startedAt">,
|
|
||||||
) => void;
|
|
||||||
/** Get active task for a session */
|
|
||||||
getActiveTask: (sessionId: string) => ActiveTaskInfo | undefined;
|
|
||||||
/** Clear active task when operation completes */
|
|
||||||
clearActiveTask: (sessionId: string) => void;
|
|
||||||
/** Reconnect to an existing task stream */
|
|
||||||
reconnectToTask: (
|
|
||||||
sessionId: string,
|
|
||||||
taskId: string,
|
|
||||||
lastMessageId?: string,
|
|
||||||
onChunk?: (chunk: StreamChunk) => void,
|
|
||||||
) => Promise<void>;
|
|
||||||
/** Update last message ID for a task (for tracking replay position) */
|
|
||||||
updateTaskLastMessageId: (sessionId: string, lastMessageId: string) => void;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatStore = ChatStoreState & ChatStoreActions;
|
type ChatStore = ChatStoreState & ChatStoreActions;
|
||||||
@@ -132,126 +64,18 @@ function cleanupExpiredStreams(
|
|||||||
const now = Date.now();
|
const now = Date.now();
|
||||||
const cleaned = new Map(completedStreams);
|
const cleaned = new Map(completedStreams);
|
||||||
for (const [sessionId, result] of cleaned) {
|
for (const [sessionId, result] of cleaned) {
|
||||||
if (now - result.completedAt > COMPLETED_STREAM_TTL_MS) {
|
if (now - result.completedAt > COMPLETED_STREAM_TTL) {
|
||||||
cleaned.delete(sessionId);
|
cleaned.delete(sessionId);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return cleaned;
|
return cleaned;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
 * Retire a stream once it has stopped streaming.
 *
 * Always detaches `onChunk` (when given) from the stream's chunk callbacks.
 * If the stream's status is no longer "streaming" and it is still the stream
 * registered for `sessionId`, it is removed from `activeStreams`, recorded in
 * `completedStreams` (which is then pruned of expired entries), and the
 * stream-complete callbacks are notified for "completed" and "error".
 */
function finalizeStream(
  sessionId: string,
  stream: ActiveStream,
  onChunk: ((chunk: StreamChunk) => void) | undefined,
  get: () => ChatStoreState & ChatStoreActions,
  set: (state: Partial<ChatStoreState>) => void,
): void {
  if (onChunk) stream.onChunkCallbacks.delete(onChunk);

  // Still running: nothing to move yet.
  if (stream.status === "streaming") return;

  const snapshot = get();

  // A different stream now owns this session; leave the store untouched.
  if (snapshot.activeStreams.get(sessionId) !== stream) return;

  const remainingActive = new Map(snapshot.activeStreams);
  const completedWithResult = new Map(snapshot.completedStreams);

  const outcome: StreamResult = {
    sessionId,
    status: stream.status,
    chunks: stream.chunks,
    completedAt: Date.now(),
    error: stream.error,
  };
  completedWithResult.set(sessionId, outcome);
  remainingActive.delete(sessionId);

  set({
    activeStreams: remainingActive,
    completedStreams: cleanupExpiredStreams(completedWithResult),
  });

  const isTerminal =
    stream.status === "completed" || stream.status === "error";
  if (isTerminal) {
    notifyStreamComplete(snapshot.streamCompleteCallbacks, sessionId);
  }
}
|
|
||||||
|
|
||||||
/**
 * Abort and archive whatever stream is currently registered for a session.
 *
 * Works on copies of both maps and returns them. When a stream exists for
 * `sessionId`, it is aborted, stored in the completed map (a stream that was
 * still "streaming" is recorded as "completed"), removed from the active map,
 * and the callbacks are notified for "completed" / "error". Expired completed
 * entries are pruned in that case.
 */
function cleanupExistingStream(
  sessionId: string,
  activeStreams: Map<string, ActiveStream>,
  completedStreams: Map<string, StreamResult>,
  callbacks: Set<StreamCompleteCallback>,
): {
  activeStreams: Map<string, ActiveStream>;
  completedStreams: Map<string, StreamResult>;
} {
  const nextActive = new Map(activeStreams);
  const nextCompleted = new Map(completedStreams);

  const previous = nextActive.get(sessionId);
  if (!previous) {
    // No stream to retire: hand back plain copies.
    return { activeStreams: nextActive, completedStreams: nextCompleted };
  }

  previous.abortController.abort();

  // An interrupted stream is archived as completed rather than streaming.
  const archivedStatus =
    previous.status === "streaming" ? "completed" : previous.status;

  const archived: StreamResult = {
    sessionId,
    status: archivedStatus,
    chunks: previous.chunks,
    completedAt: Date.now(),
    error: previous.error,
  };
  nextCompleted.set(sessionId, archived);
  nextActive.delete(sessionId);

  const prunedCompleted = cleanupExpiredStreams(nextCompleted);

  if (archivedStatus === "completed" || archivedStatus === "error") {
    notifyStreamComplete(callbacks, sessionId);
  }

  return { activeStreams: nextActive, completedStreams: prunedCompleted };
}
|
|
||||||
|
|
||||||
/**
 * Build a fresh ActiveStream record for a session.
 *
 * The stream starts in the "streaming" state with no chunks, its own
 * AbortController, and a callback set holding `onChunk` when one is given.
 */
function createActiveStream(
  sessionId: string,
  onChunk?: (chunk: StreamChunk) => void,
): ActiveStream {
  const listeners = new Set<(chunk: StreamChunk) => void>(
    onChunk ? [onChunk] : [],
  );

  return {
    sessionId,
    abortController: new AbortController(),
    status: "streaming",
    startedAt: Date.now(),
    chunks: [],
    onChunkCallbacks: listeners,
  };
}
|
|
||||||
|
|
||||||
export const useChatStore = create<ChatStore>((set, get) => ({
|
export const useChatStore = create<ChatStore>((set, get) => ({
|
||||||
activeStreams: new Map(),
|
activeStreams: new Map(),
|
||||||
completedStreams: new Map(),
|
completedStreams: new Map(),
|
||||||
activeSessions: new Set(),
|
activeSessions: new Set(),
|
||||||
streamCompleteCallbacks: new Set(),
|
streamCompleteCallbacks: new Set(),
|
||||||
activeTasks: loadPersistedTasks(),
|
|
||||||
|
|
||||||
startStream: async function startStream(
|
startStream: async function startStream(
|
||||||
sessionId,
|
sessionId,
|
||||||
@@ -261,21 +85,45 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
|||||||
onChunk,
|
onChunk,
|
||||||
) {
|
) {
|
||||||
const state = get();
|
const state = get();
|
||||||
|
const newActiveStreams = new Map(state.activeStreams);
|
||||||
|
let newCompletedStreams = new Map(state.completedStreams);
|
||||||
const callbacks = state.streamCompleteCallbacks;
|
const callbacks = state.streamCompleteCallbacks;
|
||||||
|
|
||||||
// Clean up any existing stream for this session
|
const existingStream = newActiveStreams.get(sessionId);
|
||||||
const {
|
if (existingStream) {
|
||||||
activeStreams: newActiveStreams,
|
existingStream.abortController.abort();
|
||||||
completedStreams: newCompletedStreams,
|
const normalizedStatus =
|
||||||
} = cleanupExistingStream(
|
existingStream.status === "streaming"
|
||||||
sessionId,
|
? "completed"
|
||||||
state.activeStreams,
|
: existingStream.status;
|
||||||
state.completedStreams,
|
const result: StreamResult = {
|
||||||
callbacks,
|
sessionId,
|
||||||
);
|
status: normalizedStatus,
|
||||||
|
chunks: existingStream.chunks,
|
||||||
|
completedAt: Date.now(),
|
||||||
|
error: existingStream.error,
|
||||||
|
};
|
||||||
|
newCompletedStreams.set(sessionId, result);
|
||||||
|
newActiveStreams.delete(sessionId);
|
||||||
|
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
|
||||||
|
if (normalizedStatus === "completed" || normalizedStatus === "error") {
|
||||||
|
notifyStreamComplete(callbacks, sessionId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const abortController = new AbortController();
|
||||||
|
const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
|
||||||
|
if (onChunk) initialCallbacks.add(onChunk);
|
||||||
|
|
||||||
|
const stream: ActiveStream = {
|
||||||
|
sessionId,
|
||||||
|
abortController,
|
||||||
|
status: "streaming",
|
||||||
|
startedAt: Date.now(),
|
||||||
|
chunks: [],
|
||||||
|
onChunkCallbacks: initialCallbacks,
|
||||||
|
};
|
||||||
|
|
||||||
// Create new stream
|
|
||||||
const stream = createActiveStream(sessionId, onChunk);
|
|
||||||
newActiveStreams.set(sessionId, stream);
|
newActiveStreams.set(sessionId, stream);
|
||||||
set({
|
set({
|
||||||
activeStreams: newActiveStreams,
|
activeStreams: newActiveStreams,
|
||||||
@@ -285,7 +133,36 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
|||||||
try {
|
try {
|
||||||
await executeStream(stream, message, isUserMessage, context);
|
await executeStream(stream, message, isUserMessage, context);
|
||||||
} finally {
|
} finally {
|
||||||
finalizeStream(sessionId, stream, onChunk, get, set);
|
if (onChunk) stream.onChunkCallbacks.delete(onChunk);
|
||||||
|
if (stream.status !== "streaming") {
|
||||||
|
const currentState = get();
|
||||||
|
const finalActiveStreams = new Map(currentState.activeStreams);
|
||||||
|
let finalCompletedStreams = new Map(currentState.completedStreams);
|
||||||
|
|
||||||
|
const storedStream = finalActiveStreams.get(sessionId);
|
||||||
|
if (storedStream === stream) {
|
||||||
|
const result: StreamResult = {
|
||||||
|
sessionId,
|
||||||
|
status: stream.status,
|
||||||
|
chunks: stream.chunks,
|
||||||
|
completedAt: Date.now(),
|
||||||
|
error: stream.error,
|
||||||
|
};
|
||||||
|
finalCompletedStreams.set(sessionId, result);
|
||||||
|
finalActiveStreams.delete(sessionId);
|
||||||
|
finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
|
||||||
|
set({
|
||||||
|
activeStreams: finalActiveStreams,
|
||||||
|
completedStreams: finalCompletedStreams,
|
||||||
|
});
|
||||||
|
if (stream.status === "completed" || stream.status === "error") {
|
||||||
|
notifyStreamComplete(
|
||||||
|
currentState.streamCompleteCallbacks,
|
||||||
|
sessionId,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -409,93 +286,4 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
|||||||
set({ streamCompleteCallbacks: cleanedCallbacks });
|
set({ streamCompleteCallbacks: cleanedCallbacks });
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
|
|
||||||
setActiveTask: function setActiveTask(sessionId, taskInfo) {
  // Stamp the entry with its session and start time, then store and persist.
  const entry = {
    ...taskInfo,
    sessionId,
    startedAt: Date.now(),
  };

  const updatedTasks = new Map(get().activeTasks);
  updatedTasks.set(sessionId, entry);

  set({ activeTasks: updatedTasks });
  persistTasks(updatedTasks);
},
|
|
||||||
|
|
||||||
getActiveTask: function getActiveTask(sessionId) {
  // Look up the tracked task for this session; undefined when none exists.
  const { activeTasks } = get();
  return activeTasks.get(sessionId);
},
|
|
||||||
|
|
||||||
clearActiveTask: function clearActiveTask(sessionId) {
  const { activeTasks } = get();

  // Only touch the store (and localStorage) when there is an entry to drop.
  if (activeTasks.has(sessionId)) {
    const remainingTasks = new Map(activeTasks);
    remainingTasks.delete(sessionId);
    set({ activeTasks: remainingTasks });
    persistTasks(remainingTasks);
  }
},
|
|
||||||
|
|
||||||
reconnectToTask: async function reconnectToTask(
  sessionId,
  taskId,
  lastMessageId = INITIAL_STREAM_ID,
  onChunk,
) {
  const { activeStreams, completedStreams, streamCompleteCallbacks } = get();

  // Retire whatever stream this session already had before opening a new one.
  const cleaned = cleanupExistingStream(
    sessionId,
    activeStreams,
    completedStreams,
    streamCompleteCallbacks,
  );

  // Register the replacement stream that will carry the reconnection.
  const stream = createActiveStream(sessionId, onChunk);
  cleaned.activeStreams.set(sessionId, stream);
  set({
    activeStreams: cleaned.activeStreams,
    completedStreams: cleaned.completedStreams,
  });

  try {
    await executeTaskReconnect(stream, taskId, lastMessageId);
  } finally {
    finalizeStream(sessionId, stream, onChunk, get, set);

    // Once the stream completed or errored, drop the tracked task. The
    // store's clearActiveTask is a no-op when no task is tracked.
    const reachedTerminalState =
      stream.status === "completed" || stream.status === "error";
    if (reachedTerminalState) {
      get().clearActiveTask(sessionId);
    }
  }
},
|
|
||||||
|
|
||||||
updateTaskLastMessageId: function updateTaskLastMessageId(
  sessionId,
  lastMessageId,
) {
  // Record the replay position for the session's tracked task.
  // No-op when the session has no tracked task.
  const state = get();
  const task = state.activeTasks.get(sessionId);
  if (!task) return;

  // Position unchanged: skip rebuilding the map, the store update (which
  // would notify subscribers) and the synchronous localStorage write.
  if (task.lastMessageId === lastMessageId) return;

  const newActiveTasks = new Map(state.activeTasks);
  newActiveTasks.set(sessionId, {
    ...task,
    lastMessageId,
  });
  set({ activeTasks: newActiveTasks });
  persistTasks(newActiveTasks);
},
|
|
||||||
}));
|
}));
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ export type StreamStatus = "idle" | "streaming" | "completed" | "error";
|
|||||||
|
|
||||||
export interface StreamChunk {
|
export interface StreamChunk {
|
||||||
type:
|
type:
|
||||||
| "stream_start"
|
|
||||||
| "text_chunk"
|
| "text_chunk"
|
||||||
| "text_ended"
|
| "text_ended"
|
||||||
| "tool_call"
|
| "tool_call"
|
||||||
@@ -16,7 +15,6 @@ export interface StreamChunk {
|
|||||||
| "error"
|
| "error"
|
||||||
| "usage"
|
| "usage"
|
||||||
| "stream_end";
|
| "stream_end";
|
||||||
taskId?: string;
|
|
||||||
timestamp?: string;
|
timestamp?: string;
|
||||||
content?: string;
|
content?: string;
|
||||||
message?: string;
|
message?: string;
|
||||||
@@ -43,7 +41,7 @@ export interface StreamChunk {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export type VercelStreamChunk =
|
export type VercelStreamChunk =
|
||||||
| { type: "start"; messageId: string; taskId?: string }
|
| { type: "start"; messageId: string }
|
||||||
| { type: "finish" }
|
| { type: "finish" }
|
||||||
| { type: "text-start"; id: string }
|
| { type: "text-start"; id: string }
|
||||||
| { type: "text-delta"; id: string; delta: string }
|
| { type: "text-delta"; id: string; delta: string }
|
||||||
@@ -94,70 +92,3 @@ export interface StreamResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export type StreamCompleteCallback = (sessionId: string) => void;
|
export type StreamCompleteCallback = (sessionId: string) => void;
|
||||||
|
|
||||||
// Type guards for message types
|
|
||||||
|
|
||||||
/**
 * Type guard: true when `msg` carries a `toolId` property whose value is a
 * string.
 */
export function hasToolId<T extends { type: string }>(
  msg: T,
): msg is T & { toolId: string } {
  if (!("toolId" in msg)) return false;
  const candidate = (msg as Record<string, unknown>).toolId;
  return typeof candidate === "string";
}
|
|
||||||
|
|
||||||
/**
 * Type guard: true when `msg` carries an `operationId` property whose value
 * is a string.
 */
export function hasOperationId<T extends { type: string }>(
  msg: T,
): msg is T & { operationId: string } {
  if (!("operationId" in msg)) return false;
  const candidate = (msg as Record<string, unknown>).operationId;
  return typeof candidate === "string";
}
|
|
||||||
|
|
||||||
/**
 * Type guard: true when `msg` carries a `toolCallId` property whose value is
 * a string.
 */
export function hasToolCallId<T extends { type: string }>(
  msg: T,
): msg is T & { toolCallId: string } {
  if (!("toolCallId" in msg)) return false;
  const candidate = (msg as Record<string, unknown>).toolCallId;
  return typeof candidate === "string";
}
|
|
||||||
|
|
||||||
/**
 * Type guard: true when the message's `type` is one of the operation
 * message types ("operation_started", "operation_pending",
 * "operation_in_progress").
 */
export function isOperationMessage<T extends { type: string }>(
  msg: T,
): msg is T & {
  type: "operation_started" | "operation_pending" | "operation_in_progress";
} {
  switch (msg.type) {
    case "operation_started":
    case "operation_pending":
    case "operation_in_progress":
      return true;
    default:
      return false;
  }
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the tool ID from a message if available.
|
|
||||||
* Checks toolId, operationId, and toolCallId properties.
|
|
||||||
*/
|
|
||||||
export function getToolIdFromMessage<T extends { type: string }>(
|
|
||||||
msg: T,
|
|
||||||
): string | undefined {
|
|
||||||
const record = msg as Record<string, unknown>;
|
|
||||||
if (typeof record.toolId === "string") return record.toolId;
|
|
||||||
if (typeof record.operationId === "string") return record.operationId;
|
|
||||||
if (typeof record.toolCallId === "string") return record.toolCallId;
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -17,13 +17,6 @@ export interface ChatContainerProps {
|
|||||||
className?: string;
|
className?: string;
|
||||||
onStreamingChange?: (isStreaming: boolean) => void;
|
onStreamingChange?: (isStreaming: boolean) => void;
|
||||||
onOperationStarted?: () => void;
|
onOperationStarted?: () => void;
|
||||||
/** Active stream info from the server for reconnection */
|
|
||||||
activeStream?: {
|
|
||||||
taskId: string;
|
|
||||||
lastMessageId: string;
|
|
||||||
operationId: string;
|
|
||||||
toolName: string;
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export function ChatContainer({
|
export function ChatContainer({
|
||||||
@@ -33,7 +26,6 @@ export function ChatContainer({
|
|||||||
className,
|
className,
|
||||||
onStreamingChange,
|
onStreamingChange,
|
||||||
onOperationStarted,
|
onOperationStarted,
|
||||||
activeStream,
|
|
||||||
}: ChatContainerProps) {
|
}: ChatContainerProps) {
|
||||||
const {
|
const {
|
||||||
messages,
|
messages,
|
||||||
@@ -49,7 +41,6 @@ export function ChatContainer({
|
|||||||
initialMessages,
|
initialMessages,
|
||||||
initialPrompt,
|
initialPrompt,
|
||||||
onOperationStarted,
|
onOperationStarted,
|
||||||
activeStream,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import { toast } from "sonner";
|
|||||||
import type { StreamChunk } from "../../chat-types";
|
import type { StreamChunk } from "../../chat-types";
|
||||||
import type { HandlerDependencies } from "./handlers";
|
import type { HandlerDependencies } from "./handlers";
|
||||||
import {
|
import {
|
||||||
getErrorDisplayMessage,
|
|
||||||
handleError,
|
handleError,
|
||||||
handleLoginNeeded,
|
handleLoginNeeded,
|
||||||
handleStreamEnd,
|
handleStreamEnd,
|
||||||
@@ -25,22 +24,16 @@ export function createStreamEventDispatcher(
|
|||||||
chunk.type === "need_login" ||
|
chunk.type === "need_login" ||
|
||||||
chunk.type === "error"
|
chunk.type === "error"
|
||||||
) {
|
) {
|
||||||
|
if (!deps.hasResponseRef.current) {
|
||||||
|
console.info("[ChatStream] First response chunk:", {
|
||||||
|
type: chunk.type,
|
||||||
|
sessionId: deps.sessionId,
|
||||||
|
});
|
||||||
|
}
|
||||||
deps.hasResponseRef.current = true;
|
deps.hasResponseRef.current = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (chunk.type) {
|
switch (chunk.type) {
|
||||||
case "stream_start":
|
|
||||||
// Store task ID for SSE reconnection
|
|
||||||
if (chunk.taskId && deps.onActiveTaskStarted) {
|
|
||||||
deps.onActiveTaskStarted({
|
|
||||||
taskId: chunk.taskId,
|
|
||||||
operationId: chunk.taskId,
|
|
||||||
toolName: "chat",
|
|
||||||
toolCallId: "chat_stream",
|
|
||||||
});
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case "text_chunk":
|
case "text_chunk":
|
||||||
handleTextChunk(chunk, deps);
|
handleTextChunk(chunk, deps);
|
||||||
break;
|
break;
|
||||||
@@ -63,7 +56,11 @@ export function createStreamEventDispatcher(
|
|||||||
break;
|
break;
|
||||||
|
|
||||||
case "stream_end":
|
case "stream_end":
|
||||||
// Note: "finish" type from backend gets normalized to "stream_end" by normalizeStreamChunk
|
console.info("[ChatStream] Stream ended:", {
|
||||||
|
sessionId: deps.sessionId,
|
||||||
|
hasResponse: deps.hasResponseRef.current,
|
||||||
|
chunkCount: deps.streamingChunksRef.current.length,
|
||||||
|
});
|
||||||
handleStreamEnd(chunk, deps);
|
handleStreamEnd(chunk, deps);
|
||||||
break;
|
break;
|
||||||
|
|
||||||
@@ -73,7 +70,7 @@ export function createStreamEventDispatcher(
|
|||||||
// Show toast at dispatcher level to avoid circular dependencies
|
// Show toast at dispatcher level to avoid circular dependencies
|
||||||
if (!isRegionBlocked) {
|
if (!isRegionBlocked) {
|
||||||
toast.error("Chat Error", {
|
toast.error("Chat Error", {
|
||||||
description: getErrorDisplayMessage(chunk),
|
description: chunk.message || chunk.content || "An error occurred",
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|||||||
@@ -18,19 +18,11 @@ export interface HandlerDependencies {
|
|||||||
setStreamingChunks: Dispatch<SetStateAction<string[]>>;
|
setStreamingChunks: Dispatch<SetStateAction<string[]>>;
|
||||||
streamingChunksRef: MutableRefObject<string[]>;
|
streamingChunksRef: MutableRefObject<string[]>;
|
||||||
hasResponseRef: MutableRefObject<boolean>;
|
hasResponseRef: MutableRefObject<boolean>;
|
||||||
textFinalizedRef: MutableRefObject<boolean>;
|
|
||||||
streamEndedRef: MutableRefObject<boolean>;
|
|
||||||
setMessages: Dispatch<SetStateAction<ChatMessageData[]>>;
|
setMessages: Dispatch<SetStateAction<ChatMessageData[]>>;
|
||||||
setIsStreamingInitiated: Dispatch<SetStateAction<boolean>>;
|
setIsStreamingInitiated: Dispatch<SetStateAction<boolean>>;
|
||||||
setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
|
setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
|
||||||
sessionId: string;
|
sessionId: string;
|
||||||
onOperationStarted?: () => void;
|
onOperationStarted?: () => void;
|
||||||
onActiveTaskStarted?: (taskInfo: {
|
|
||||||
taskId: string;
|
|
||||||
operationId: string;
|
|
||||||
toolName: string;
|
|
||||||
toolCallId: string;
|
|
||||||
}) => void;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export function isRegionBlockedError(chunk: StreamChunk): boolean {
|
export function isRegionBlockedError(chunk: StreamChunk): boolean {
|
||||||
@@ -40,25 +32,6 @@ export function isRegionBlockedError(chunk: StreamChunk): boolean {
|
|||||||
return message.toLowerCase().includes("not available in your region");
|
return message.toLowerCase().includes("not available in your region");
|
||||||
}
|
}
|
||||||
|
|
||||||
export function getUserFriendlyErrorMessage(
|
|
||||||
code: string | undefined,
|
|
||||||
): string | undefined {
|
|
||||||
switch (code) {
|
|
||||||
case "TASK_EXPIRED":
|
|
||||||
return "This operation has expired. Please try again.";
|
|
||||||
case "TASK_NOT_FOUND":
|
|
||||||
return "Could not find the requested operation.";
|
|
||||||
case "ACCESS_DENIED":
|
|
||||||
return "You do not have access to this operation.";
|
|
||||||
case "QUEUE_OVERFLOW":
|
|
||||||
return "Connection was interrupted. Please refresh to continue.";
|
|
||||||
case "MODEL_NOT_AVAILABLE_REGION":
|
|
||||||
return "This model is not available in your region.";
|
|
||||||
default:
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export function handleTextChunk(chunk: StreamChunk, deps: HandlerDependencies) {
|
export function handleTextChunk(chunk: StreamChunk, deps: HandlerDependencies) {
|
||||||
if (!chunk.content) return;
|
if (!chunk.content) return;
|
||||||
deps.setHasTextChunks(true);
|
deps.setHasTextChunks(true);
|
||||||
@@ -73,15 +46,10 @@ export function handleTextEnded(
|
|||||||
_chunk: StreamChunk,
|
_chunk: StreamChunk,
|
||||||
deps: HandlerDependencies,
|
deps: HandlerDependencies,
|
||||||
) {
|
) {
|
||||||
if (deps.textFinalizedRef.current) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const completedText = deps.streamingChunksRef.current.join("");
|
const completedText = deps.streamingChunksRef.current.join("");
|
||||||
if (completedText.trim()) {
|
if (completedText.trim()) {
|
||||||
deps.textFinalizedRef.current = true;
|
|
||||||
|
|
||||||
deps.setMessages((prev) => {
|
deps.setMessages((prev) => {
|
||||||
|
// Check if this exact message already exists to prevent duplicates
|
||||||
const exists = prev.some(
|
const exists = prev.some(
|
||||||
(msg) =>
|
(msg) =>
|
||||||
msg.type === "message" &&
|
msg.type === "message" &&
|
||||||
@@ -108,14 +76,9 @@ export function handleToolCallStart(
|
|||||||
chunk: StreamChunk,
|
chunk: StreamChunk,
|
||||||
deps: HandlerDependencies,
|
deps: HandlerDependencies,
|
||||||
) {
|
) {
|
||||||
// Use deterministic fallback instead of Date.now() to ensure same ID on replay
|
|
||||||
const toolId =
|
|
||||||
chunk.tool_id ||
|
|
||||||
`tool-${deps.sessionId}-${chunk.idx ?? "unknown"}-${chunk.tool_name || "unknown"}`;
|
|
||||||
|
|
||||||
const toolCallMessage: Extract<ChatMessageData, { type: "tool_call" }> = {
|
const toolCallMessage: Extract<ChatMessageData, { type: "tool_call" }> = {
|
||||||
type: "tool_call",
|
type: "tool_call",
|
||||||
toolId,
|
toolId: chunk.tool_id || `tool-${Date.now()}-${chunk.idx || 0}`,
|
||||||
toolName: chunk.tool_name || "Executing",
|
toolName: chunk.tool_name || "Executing",
|
||||||
arguments: chunk.arguments || {},
|
arguments: chunk.arguments || {},
|
||||||
timestamp: new Date(),
|
timestamp: new Date(),
|
||||||
@@ -148,29 +111,6 @@ export function handleToolCallStart(
|
|||||||
deps.setMessages(updateToolCallMessages);
|
deps.setMessages(updateToolCallMessages);
|
||||||
}
|
}
|
||||||
|
|
||||||
const TOOL_RESPONSE_TYPES = new Set([
|
|
||||||
"tool_response",
|
|
||||||
"operation_started",
|
|
||||||
"operation_pending",
|
|
||||||
"operation_in_progress",
|
|
||||||
"execution_started",
|
|
||||||
"agent_carousel",
|
|
||||||
"clarification_needed",
|
|
||||||
]);
|
|
||||||
|
|
||||||
function hasResponseForTool(
|
|
||||||
messages: ChatMessageData[],
|
|
||||||
toolId: string,
|
|
||||||
): boolean {
|
|
||||||
return messages.some((msg) => {
|
|
||||||
if (!TOOL_RESPONSE_TYPES.has(msg.type)) return false;
|
|
||||||
const msgToolId =
|
|
||||||
(msg as { toolId?: string }).toolId ||
|
|
||||||
(msg as { toolCallId?: string }).toolCallId;
|
|
||||||
return msgToolId === toolId;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
export function handleToolResponse(
|
export function handleToolResponse(
|
||||||
chunk: StreamChunk,
|
chunk: StreamChunk,
|
||||||
deps: HandlerDependencies,
|
deps: HandlerDependencies,
|
||||||
@@ -212,49 +152,31 @@ export function handleToolResponse(
|
|||||||
) {
|
) {
|
||||||
const inputsMessage = extractInputsNeeded(parsedResult, chunk.tool_name);
|
const inputsMessage = extractInputsNeeded(parsedResult, chunk.tool_name);
|
||||||
if (inputsMessage) {
|
if (inputsMessage) {
|
||||||
deps.setMessages((prev) => {
|
deps.setMessages((prev) => [...prev, inputsMessage]);
|
||||||
// Check for duplicate inputs_needed message
|
|
||||||
const exists = prev.some((msg) => msg.type === "inputs_needed");
|
|
||||||
if (exists) return prev;
|
|
||||||
return [...prev, inputsMessage];
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
const credentialsMessage = extractCredentialsNeeded(
|
const credentialsMessage = extractCredentialsNeeded(
|
||||||
parsedResult,
|
parsedResult,
|
||||||
chunk.tool_name,
|
chunk.tool_name,
|
||||||
);
|
);
|
||||||
if (credentialsMessage) {
|
if (credentialsMessage) {
|
||||||
deps.setMessages((prev) => {
|
deps.setMessages((prev) => [...prev, credentialsMessage]);
|
||||||
// Check for duplicate credentials_needed message
|
|
||||||
const exists = prev.some((msg) => msg.type === "credentials_needed");
|
|
||||||
if (exists) return prev;
|
|
||||||
return [...prev, credentialsMessage];
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
// Trigger polling when operation_started is received
|
||||||
if (responseMessage.type === "operation_started") {
|
if (responseMessage.type === "operation_started") {
|
||||||
deps.onOperationStarted?.();
|
deps.onOperationStarted?.();
|
||||||
const taskId = (responseMessage as { taskId?: string }).taskId;
|
|
||||||
if (taskId && deps.onActiveTaskStarted) {
|
|
||||||
deps.onActiveTaskStarted({
|
|
||||||
taskId,
|
|
||||||
operationId:
|
|
||||||
(responseMessage as { operationId?: string }).operationId || "",
|
|
||||||
toolName: (responseMessage as { toolName?: string }).toolName || "",
|
|
||||||
toolCallId: (responseMessage as { toolId?: string }).toolId || "",
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
deps.setMessages((prev) => {
|
deps.setMessages((prev) => {
|
||||||
const toolCallIndex = prev.findIndex(
|
const toolCallIndex = prev.findIndex(
|
||||||
(msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id,
|
(msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id,
|
||||||
);
|
);
|
||||||
if (hasResponseForTool(prev, chunk.tool_id!)) {
|
const hasResponse = prev.some(
|
||||||
return prev;
|
(msg) => msg.type === "tool_response" && msg.toolId === chunk.tool_id,
|
||||||
}
|
);
|
||||||
|
if (hasResponse) return prev;
|
||||||
if (toolCallIndex !== -1) {
|
if (toolCallIndex !== -1) {
|
||||||
const newMessages = [...prev];
|
const newMessages = [...prev];
|
||||||
newMessages.splice(toolCallIndex + 1, 0, responseMessage);
|
newMessages.splice(toolCallIndex + 1, 0, responseMessage);
|
||||||
@@ -276,48 +198,28 @@ export function handleLoginNeeded(
|
|||||||
agentInfo: chunk.agent_info,
|
agentInfo: chunk.agent_info,
|
||||||
timestamp: new Date(),
|
timestamp: new Date(),
|
||||||
};
|
};
|
||||||
deps.setMessages((prev) => {
|
deps.setMessages((prev) => [...prev, loginNeededMessage]);
|
||||||
// Check for duplicate login_needed message
|
|
||||||
const exists = prev.some((msg) => msg.type === "login_needed");
|
|
||||||
if (exists) return prev;
|
|
||||||
return [...prev, loginNeededMessage];
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export function handleStreamEnd(
|
export function handleStreamEnd(
|
||||||
_chunk: StreamChunk,
|
_chunk: StreamChunk,
|
||||||
deps: HandlerDependencies,
|
deps: HandlerDependencies,
|
||||||
) {
|
) {
|
||||||
if (deps.streamEndedRef.current) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
deps.streamEndedRef.current = true;
|
|
||||||
|
|
||||||
const completedContent = deps.streamingChunksRef.current.join("");
|
const completedContent = deps.streamingChunksRef.current.join("");
|
||||||
if (!completedContent.trim() && !deps.hasResponseRef.current) {
|
if (!completedContent.trim() && !deps.hasResponseRef.current) {
|
||||||
deps.setMessages((prev) => {
|
deps.setMessages((prev) => [
|
||||||
const exists = prev.some(
|
...prev,
|
||||||
(msg) =>
|
{
|
||||||
msg.type === "message" &&
|
type: "message",
|
||||||
msg.role === "assistant" &&
|
role: "assistant",
|
||||||
msg.content === "No response received. Please try again.",
|
content: "No response received. Please try again.",
|
||||||
);
|
timestamp: new Date(),
|
||||||
if (exists) return prev;
|
},
|
||||||
return [
|
]);
|
||||||
...prev,
|
|
||||||
{
|
|
||||||
type: "message",
|
|
||||||
role: "assistant",
|
|
||||||
content: "No response received. Please try again.",
|
|
||||||
timestamp: new Date(),
|
|
||||||
},
|
|
||||||
];
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
if (completedContent.trim() && !deps.textFinalizedRef.current) {
|
if (completedContent.trim()) {
|
||||||
deps.textFinalizedRef.current = true;
|
|
||||||
|
|
||||||
deps.setMessages((prev) => {
|
deps.setMessages((prev) => {
|
||||||
|
// Check if this exact message already exists to prevent duplicates
|
||||||
const exists = prev.some(
|
const exists = prev.some(
|
||||||
(msg) =>
|
(msg) =>
|
||||||
msg.type === "message" &&
|
msg.type === "message" &&
|
||||||
@@ -342,6 +244,8 @@ export function handleStreamEnd(
|
|||||||
}
|
}
|
||||||
|
|
||||||
export function handleError(chunk: StreamChunk, deps: HandlerDependencies) {
|
export function handleError(chunk: StreamChunk, deps: HandlerDependencies) {
|
||||||
|
const errorMessage = chunk.message || chunk.content || "An error occurred";
|
||||||
|
console.error("Stream error:", errorMessage);
|
||||||
if (isRegionBlockedError(chunk)) {
|
if (isRegionBlockedError(chunk)) {
|
||||||
deps.setIsRegionBlockedModalOpen(true);
|
deps.setIsRegionBlockedModalOpen(true);
|
||||||
}
|
}
|
||||||
@@ -349,14 +253,4 @@ export function handleError(chunk: StreamChunk, deps: HandlerDependencies) {
|
|||||||
deps.setHasTextChunks(false);
|
deps.setHasTextChunks(false);
|
||||||
deps.setStreamingChunks([]);
|
deps.setStreamingChunks([]);
|
||||||
deps.streamingChunksRef.current = [];
|
deps.streamingChunksRef.current = [];
|
||||||
deps.textFinalizedRef.current = false;
|
|
||||||
deps.streamEndedRef.current = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function getErrorDisplayMessage(chunk: StreamChunk): string {
|
|
||||||
const friendlyMessage = getUserFriendlyErrorMessage(chunk.code);
|
|
||||||
if (friendlyMessage) {
|
|
||||||
return friendlyMessage;
|
|
||||||
}
|
|
||||||
return chunk.message || chunk.content || "An error occurred";
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -349,7 +349,6 @@ export function parseToolResponse(
|
|||||||
toolName: (parsedResult.tool_name as string) || toolName,
|
toolName: (parsedResult.tool_name as string) || toolName,
|
||||||
toolId,
|
toolId,
|
||||||
operationId: (parsedResult.operation_id as string) || "",
|
operationId: (parsedResult.operation_id as string) || "",
|
||||||
taskId: (parsedResult.task_id as string) || undefined, // For SSE reconnection
|
|
||||||
message:
|
message:
|
||||||
(parsedResult.message as string) ||
|
(parsedResult.message as string) ||
|
||||||
"Operation started. You can close this tab.",
|
"Operation started. You can close this tab.",
|
||||||
|
|||||||
@@ -1,17 +1,10 @@
|
|||||||
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
|
||||||
import { useEffect, useMemo, useRef, useState } from "react";
|
import { useEffect, useMemo, useRef, useState } from "react";
|
||||||
import { INITIAL_STREAM_ID } from "../../chat-constants";
|
|
||||||
import { useChatStore } from "../../chat-store";
|
import { useChatStore } from "../../chat-store";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import { useChatStream } from "../../useChatStream";
|
import { useChatStream } from "../../useChatStream";
|
||||||
import { usePageContext } from "../../usePageContext";
|
import { usePageContext } from "../../usePageContext";
|
||||||
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
|
import type { ChatMessageData } from "../ChatMessage/useChatMessage";
|
||||||
import {
|
|
||||||
getToolIdFromMessage,
|
|
||||||
hasToolId,
|
|
||||||
isOperationMessage,
|
|
||||||
type StreamChunk,
|
|
||||||
} from "../../chat-types";
|
|
||||||
import { createStreamEventDispatcher } from "./createStreamEventDispatcher";
|
import { createStreamEventDispatcher } from "./createStreamEventDispatcher";
|
||||||
import {
|
import {
|
||||||
createUserMessage,
|
createUserMessage,
|
||||||
@@ -21,13 +14,6 @@ import {
|
|||||||
processInitialMessages,
|
processInitialMessages,
|
||||||
} from "./helpers";
|
} from "./helpers";
|
||||||
|
|
||||||
const TOOL_RESULT_TYPES = new Set([
|
|
||||||
"tool_response",
|
|
||||||
"agent_carousel",
|
|
||||||
"execution_started",
|
|
||||||
"clarification_needed",
|
|
||||||
]);
|
|
||||||
|
|
||||||
// Helper to generate deduplication key for a message
|
// Helper to generate deduplication key for a message
|
||||||
function getMessageKey(msg: ChatMessageData): string {
|
function getMessageKey(msg: ChatMessageData): string {
|
||||||
if (msg.type === "message") {
|
if (msg.type === "message") {
|
||||||
@@ -37,18 +23,14 @@ function getMessageKey(msg: ChatMessageData): string {
|
|||||||
return `msg:${msg.role}:${msg.content}`;
|
return `msg:${msg.role}:${msg.content}`;
|
||||||
} else if (msg.type === "tool_call") {
|
} else if (msg.type === "tool_call") {
|
||||||
return `toolcall:${msg.toolId}`;
|
return `toolcall:${msg.toolId}`;
|
||||||
} else if (TOOL_RESULT_TYPES.has(msg.type)) {
|
} else if (msg.type === "tool_response") {
|
||||||
// Unified key for all tool result types - same toolId with different types
|
return `toolresponse:${(msg as any).toolId}`;
|
||||||
// (tool_response vs agent_carousel) should deduplicate to the same key
|
} else if (
|
||||||
const toolId = getToolIdFromMessage(msg);
|
msg.type === "operation_started" ||
|
||||||
// If no toolId, fall back to content-based key to avoid empty key collisions
|
msg.type === "operation_pending" ||
|
||||||
if (!toolId) {
|
msg.type === "operation_in_progress"
|
||||||
return `toolresult:content:${JSON.stringify(msg).slice(0, 200)}`;
|
) {
|
||||||
}
|
return `op:${(msg as any).toolId || (msg as any).operationId || (msg as any).toolCallId || ""}:${msg.toolName}`;
|
||||||
return `toolresult:${toolId}`;
|
|
||||||
} else if (isOperationMessage(msg)) {
|
|
||||||
const toolId = getToolIdFromMessage(msg) || "";
|
|
||||||
return `op:${toolId}:${msg.toolName}`;
|
|
||||||
} else {
|
} else {
|
||||||
return `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`;
|
return `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`;
|
||||||
}
|
}
|
||||||
@@ -59,13 +41,6 @@ interface Args {
|
|||||||
initialMessages: SessionDetailResponse["messages"];
|
initialMessages: SessionDetailResponse["messages"];
|
||||||
initialPrompt?: string;
|
initialPrompt?: string;
|
||||||
onOperationStarted?: () => void;
|
onOperationStarted?: () => void;
|
||||||
/** Active stream info from the server for reconnection */
|
|
||||||
activeStream?: {
|
|
||||||
taskId: string;
|
|
||||||
lastMessageId: string;
|
|
||||||
operationId: string;
|
|
||||||
toolName: string;
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export function useChatContainer({
|
export function useChatContainer({
|
||||||
@@ -73,7 +48,6 @@ export function useChatContainer({
|
|||||||
initialMessages,
|
initialMessages,
|
||||||
initialPrompt,
|
initialPrompt,
|
||||||
onOperationStarted,
|
onOperationStarted,
|
||||||
activeStream,
|
|
||||||
}: Args) {
|
}: Args) {
|
||||||
const [messages, setMessages] = useState<ChatMessageData[]>([]);
|
const [messages, setMessages] = useState<ChatMessageData[]>([]);
|
||||||
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
|
const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
|
||||||
@@ -83,8 +57,6 @@ export function useChatContainer({
|
|||||||
useState(false);
|
useState(false);
|
||||||
const hasResponseRef = useRef(false);
|
const hasResponseRef = useRef(false);
|
||||||
const streamingChunksRef = useRef<string[]>([]);
|
const streamingChunksRef = useRef<string[]>([]);
|
||||||
const textFinalizedRef = useRef(false);
|
|
||||||
const streamEndedRef = useRef(false);
|
|
||||||
const previousSessionIdRef = useRef<string | null>(null);
|
const previousSessionIdRef = useRef<string | null>(null);
|
||||||
const {
|
const {
|
||||||
error,
|
error,
|
||||||
@@ -93,182 +65,44 @@ export function useChatContainer({
|
|||||||
} = useChatStream();
|
} = useChatStream();
|
||||||
const activeStreams = useChatStore((s) => s.activeStreams);
|
const activeStreams = useChatStore((s) => s.activeStreams);
|
||||||
const subscribeToStream = useChatStore((s) => s.subscribeToStream);
|
const subscribeToStream = useChatStore((s) => s.subscribeToStream);
|
||||||
const setActiveTask = useChatStore((s) => s.setActiveTask);
|
|
||||||
const getActiveTask = useChatStore((s) => s.getActiveTask);
|
|
||||||
const reconnectToTask = useChatStore((s) => s.reconnectToTask);
|
|
||||||
const isStreaming = isStreamingInitiated || hasTextChunks;
|
const isStreaming = isStreamingInitiated || hasTextChunks;
|
||||||
// Track whether we've already connected to this activeStream to avoid duplicate connections
|
|
||||||
const connectedActiveStreamRef = useRef<string | null>(null);
|
|
||||||
// Track if component is mounted to prevent state updates after unmount
|
|
||||||
const isMountedRef = useRef(true);
|
|
||||||
// Track current dispatcher to prevent multiple dispatchers from adding messages
|
|
||||||
const currentDispatcherIdRef = useRef(0);
|
|
||||||
|
|
||||||
// Set mounted flag - reset on every mount, cleanup on unmount
|
|
||||||
useEffect(function trackMountedState() {
|
|
||||||
isMountedRef.current = true;
|
|
||||||
return function cleanup() {
|
|
||||||
isMountedRef.current = false;
|
|
||||||
};
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
// Callback to store active task info for SSE reconnection
|
|
||||||
function handleActiveTaskStarted(taskInfo: {
|
|
||||||
taskId: string;
|
|
||||||
operationId: string;
|
|
||||||
toolName: string;
|
|
||||||
toolCallId: string;
|
|
||||||
}) {
|
|
||||||
if (!sessionId) return;
|
|
||||||
setActiveTask(sessionId, {
|
|
||||||
taskId: taskInfo.taskId,
|
|
||||||
operationId: taskInfo.operationId,
|
|
||||||
toolName: taskInfo.toolName,
|
|
||||||
lastMessageId: INITIAL_STREAM_ID,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create dispatcher for stream events - stable reference for current sessionId
|
|
||||||
// Each dispatcher gets a unique ID to prevent stale dispatchers from updating state
|
|
||||||
function createDispatcher() {
|
|
||||||
if (!sessionId) return () => {};
|
|
||||||
// Increment dispatcher ID - only the most recent dispatcher should update state
|
|
||||||
const dispatcherId = ++currentDispatcherIdRef.current;
|
|
||||||
|
|
||||||
const baseDispatcher = createStreamEventDispatcher({
|
|
||||||
setHasTextChunks,
|
|
||||||
setStreamingChunks,
|
|
||||||
streamingChunksRef,
|
|
||||||
hasResponseRef,
|
|
||||||
textFinalizedRef,
|
|
||||||
streamEndedRef,
|
|
||||||
setMessages,
|
|
||||||
setIsRegionBlockedModalOpen,
|
|
||||||
sessionId,
|
|
||||||
setIsStreamingInitiated,
|
|
||||||
onOperationStarted,
|
|
||||||
onActiveTaskStarted: handleActiveTaskStarted,
|
|
||||||
});
|
|
||||||
|
|
||||||
// Wrap dispatcher to check if it's still the current one
|
|
||||||
return function guardedDispatcher(chunk: StreamChunk) {
|
|
||||||
// Skip if component unmounted or this is a stale dispatcher
|
|
||||||
if (!isMountedRef.current) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (dispatcherId !== currentDispatcherIdRef.current) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
baseDispatcher(chunk);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
useEffect(
|
useEffect(
|
||||||
function handleSessionChange() {
|
function handleSessionChange() {
|
||||||
const isSessionChange = sessionId !== previousSessionIdRef.current;
|
if (sessionId === previousSessionIdRef.current) return;
|
||||||
|
|
||||||
// Handle session change - reset state
|
const prevSession = previousSessionIdRef.current;
|
||||||
if (isSessionChange) {
|
if (prevSession) {
|
||||||
const prevSession = previousSessionIdRef.current;
|
stopStreaming(prevSession);
|
||||||
if (prevSession) {
|
|
||||||
stopStreaming(prevSession);
|
|
||||||
}
|
|
||||||
previousSessionIdRef.current = sessionId;
|
|
||||||
connectedActiveStreamRef.current = null;
|
|
||||||
setMessages([]);
|
|
||||||
setStreamingChunks([]);
|
|
||||||
streamingChunksRef.current = [];
|
|
||||||
setHasTextChunks(false);
|
|
||||||
setIsStreamingInitiated(false);
|
|
||||||
hasResponseRef.current = false;
|
|
||||||
textFinalizedRef.current = false;
|
|
||||||
streamEndedRef.current = false;
|
|
||||||
}
|
}
|
||||||
|
previousSessionIdRef.current = sessionId;
|
||||||
|
setMessages([]);
|
||||||
|
setStreamingChunks([]);
|
||||||
|
streamingChunksRef.current = [];
|
||||||
|
setHasTextChunks(false);
|
||||||
|
setIsStreamingInitiated(false);
|
||||||
|
hasResponseRef.current = false;
|
||||||
|
|
||||||
if (!sessionId) return;
|
if (!sessionId) return;
|
||||||
|
|
||||||
// Priority 1: Check if server told us there's an active stream (most authoritative)
|
const activeStream = activeStreams.get(sessionId);
|
||||||
if (activeStream) {
|
if (!activeStream || activeStream.status !== "streaming") return;
|
||||||
const streamKey = `${sessionId}:${activeStream.taskId}`;
|
|
||||||
|
|
||||||
if (connectedActiveStreamRef.current === streamKey) {
|
const dispatcher = createStreamEventDispatcher({
|
||||||
return;
|
setHasTextChunks,
|
||||||
}
|
setStreamingChunks,
|
||||||
|
streamingChunksRef,
|
||||||
// Skip if there's already an active stream for this session in the store
|
hasResponseRef,
|
||||||
const existingStream = activeStreams.get(sessionId);
|
setMessages,
|
||||||
if (existingStream && existingStream.status === "streaming") {
|
setIsRegionBlockedModalOpen,
|
||||||
connectedActiveStreamRef.current = streamKey;
|
sessionId,
|
||||||
return;
|
setIsStreamingInitiated,
|
||||||
}
|
onOperationStarted,
|
||||||
|
});
|
||||||
connectedActiveStreamRef.current = streamKey;
|
|
||||||
|
|
||||||
// Clear all state before reconnection to prevent duplicates
|
|
||||||
// Server's initialMessages is authoritative; local state will be rebuilt from SSE replay
|
|
||||||
setMessages([]);
|
|
||||||
setStreamingChunks([]);
|
|
||||||
streamingChunksRef.current = [];
|
|
||||||
setHasTextChunks(false);
|
|
||||||
textFinalizedRef.current = false;
|
|
||||||
streamEndedRef.current = false;
|
|
||||||
hasResponseRef.current = false;
|
|
||||||
|
|
||||||
setIsStreamingInitiated(true);
|
|
||||||
setActiveTask(sessionId, {
|
|
||||||
taskId: activeStream.taskId,
|
|
||||||
operationId: activeStream.operationId,
|
|
||||||
toolName: activeStream.toolName,
|
|
||||||
lastMessageId: activeStream.lastMessageId,
|
|
||||||
});
|
|
||||||
reconnectToTask(
|
|
||||||
sessionId,
|
|
||||||
activeStream.taskId,
|
|
||||||
activeStream.lastMessageId,
|
|
||||||
createDispatcher(),
|
|
||||||
);
|
|
||||||
// Don't return cleanup here - the guarded dispatcher handles stale events
|
|
||||||
// and the stream will complete naturally. Cleanup would prematurely stop
|
|
||||||
// the stream when effect re-runs due to activeStreams changing.
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only check localStorage/in-memory on session change
|
|
||||||
if (!isSessionChange) return;
|
|
||||||
|
|
||||||
// Priority 2: Check localStorage for active task
|
|
||||||
const activeTask = getActiveTask(sessionId);
|
|
||||||
if (activeTask) {
|
|
||||||
// Clear all state before reconnection to prevent duplicates
|
|
||||||
// Server's initialMessages is authoritative; local state will be rebuilt from SSE replay
|
|
||||||
setMessages([]);
|
|
||||||
setStreamingChunks([]);
|
|
||||||
streamingChunksRef.current = [];
|
|
||||||
setHasTextChunks(false);
|
|
||||||
textFinalizedRef.current = false;
|
|
||||||
streamEndedRef.current = false;
|
|
||||||
hasResponseRef.current = false;
|
|
||||||
|
|
||||||
setIsStreamingInitiated(true);
|
|
||||||
reconnectToTask(
|
|
||||||
sessionId,
|
|
||||||
activeTask.taskId,
|
|
||||||
activeTask.lastMessageId,
|
|
||||||
createDispatcher(),
|
|
||||||
);
|
|
||||||
// Don't return cleanup here - the guarded dispatcher handles stale events
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Priority 3: Check for an in-memory active stream (same-tab scenario)
|
|
||||||
const inMemoryStream = activeStreams.get(sessionId);
|
|
||||||
if (!inMemoryStream || inMemoryStream.status !== "streaming") {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
setIsStreamingInitiated(true);
|
setIsStreamingInitiated(true);
|
||||||
const skipReplay = initialMessages.length > 0;
|
const skipReplay = initialMessages.length > 0;
|
||||||
return subscribeToStream(sessionId, createDispatcher(), skipReplay);
|
return subscribeToStream(sessionId, dispatcher, skipReplay);
|
||||||
},
|
},
|
||||||
[
|
[
|
||||||
sessionId,
|
sessionId,
|
||||||
@@ -276,10 +110,6 @@ export function useChatContainer({
|
|||||||
activeStreams,
|
activeStreams,
|
||||||
subscribeToStream,
|
subscribeToStream,
|
||||||
onOperationStarted,
|
onOperationStarted,
|
||||||
getActiveTask,
|
|
||||||
reconnectToTask,
|
|
||||||
activeStream,
|
|
||||||
setActiveTask,
|
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -294,7 +124,7 @@ export function useChatContainer({
|
|||||||
msg.type === "agent_carousel" ||
|
msg.type === "agent_carousel" ||
|
||||||
msg.type === "execution_started"
|
msg.type === "execution_started"
|
||||||
) {
|
) {
|
||||||
const toolId = hasToolId(msg) ? msg.toolId : undefined;
|
const toolId = (msg as any).toolId;
|
||||||
if (toolId) {
|
if (toolId) {
|
||||||
ids.add(toolId);
|
ids.add(toolId);
|
||||||
}
|
}
|
||||||
@@ -311,8 +141,12 @@ export function useChatContainer({
|
|||||||
|
|
||||||
setMessages((prev) => {
|
setMessages((prev) => {
|
||||||
const filtered = prev.filter((msg) => {
|
const filtered = prev.filter((msg) => {
|
||||||
if (isOperationMessage(msg)) {
|
if (
|
||||||
const toolId = getToolIdFromMessage(msg);
|
msg.type === "operation_started" ||
|
||||||
|
msg.type === "operation_pending" ||
|
||||||
|
msg.type === "operation_in_progress"
|
||||||
|
) {
|
||||||
|
const toolId = (msg as any).toolId || (msg as any).toolCallId;
|
||||||
if (toolId && completedToolIds.has(toolId)) {
|
if (toolId && completedToolIds.has(toolId)) {
|
||||||
return false; // Remove - operation completed
|
return false; // Remove - operation completed
|
||||||
}
|
}
|
||||||
@@ -340,8 +174,12 @@ export function useChatContainer({
|
|||||||
// Filter local messages: remove duplicates and completed operation messages
|
// Filter local messages: remove duplicates and completed operation messages
|
||||||
const newLocalMessages = messages.filter((msg) => {
|
const newLocalMessages = messages.filter((msg) => {
|
||||||
// Remove operation messages for completed tools
|
// Remove operation messages for completed tools
|
||||||
if (isOperationMessage(msg)) {
|
if (
|
||||||
const toolId = getToolIdFromMessage(msg);
|
msg.type === "operation_started" ||
|
||||||
|
msg.type === "operation_pending" ||
|
||||||
|
msg.type === "operation_in_progress"
|
||||||
|
) {
|
||||||
|
const toolId = (msg as any).toolId || (msg as any).toolCallId;
|
||||||
if (toolId && completedToolIds.has(toolId)) {
|
if (toolId && completedToolIds.has(toolId)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -352,70 +190,7 @@ export function useChatContainer({
|
|||||||
});
|
});
|
||||||
|
|
||||||
// Server messages first (correct order), then new local messages
|
// Server messages first (correct order), then new local messages
|
||||||
const combined = [...processedInitial, ...newLocalMessages];
|
return [...processedInitial, ...newLocalMessages];
|
||||||
|
|
||||||
// Post-processing: Remove duplicate assistant messages that can occur during
|
|
||||||
// race conditions (e.g., rapid screen switching during SSE reconnection).
|
|
||||||
// Two assistant messages are considered duplicates if:
|
|
||||||
// - They are both text messages with role "assistant"
|
|
||||||
// - One message's content starts with the other's content (partial vs complete)
|
|
||||||
// - Or they have very similar content (>80% overlap at the start)
|
|
||||||
const deduplicated: ChatMessageData[] = [];
|
|
||||||
for (let i = 0; i < combined.length; i++) {
|
|
||||||
const current = combined[i];
|
|
||||||
|
|
||||||
// Check if this is an assistant text message
|
|
||||||
if (current.type !== "message" || current.role !== "assistant") {
|
|
||||||
deduplicated.push(current);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Look for duplicate assistant messages in the rest of the array
|
|
||||||
let dominated = false;
|
|
||||||
for (let j = 0; j < combined.length; j++) {
|
|
||||||
if (i === j) continue;
|
|
||||||
const other = combined[j];
|
|
||||||
if (other.type !== "message" || other.role !== "assistant") continue;
|
|
||||||
|
|
||||||
const currentContent = current.content || "";
|
|
||||||
const otherContent = other.content || "";
|
|
||||||
|
|
||||||
// Skip empty messages
|
|
||||||
if (!currentContent.trim() || !otherContent.trim()) continue;
|
|
||||||
|
|
||||||
// Check if current is a prefix of other (current is incomplete version)
|
|
||||||
if (
|
|
||||||
otherContent.length > currentContent.length &&
|
|
||||||
otherContent.startsWith(currentContent.slice(0, 100))
|
|
||||||
) {
|
|
||||||
// Current is a shorter/incomplete version of other - skip it
|
|
||||||
dominated = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if messages are nearly identical (within a small difference)
|
|
||||||
// This catches cases where content differs only slightly
|
|
||||||
const minLen = Math.min(currentContent.length, otherContent.length);
|
|
||||||
const compareLen = Math.min(minLen, 200); // Compare first 200 chars
|
|
||||||
if (
|
|
||||||
compareLen > 50 &&
|
|
||||||
currentContent.slice(0, compareLen) ===
|
|
||||||
otherContent.slice(0, compareLen)
|
|
||||||
) {
|
|
||||||
// Same prefix - keep the longer one
|
|
||||||
if (otherContent.length > currentContent.length) {
|
|
||||||
dominated = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!dominated) {
|
|
||||||
deduplicated.push(current);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return deduplicated;
|
|
||||||
}, [initialMessages, messages, completedToolIds]);
|
}, [initialMessages, messages, completedToolIds]);
|
||||||
|
|
||||||
async function sendMessage(
|
async function sendMessage(
|
||||||
@@ -423,8 +198,10 @@ export function useChatContainer({
|
|||||||
isUserMessage: boolean = true,
|
isUserMessage: boolean = true,
|
||||||
context?: { url: string; content: string },
|
context?: { url: string; content: string },
|
||||||
) {
|
) {
|
||||||
if (!sessionId) return;
|
if (!sessionId) {
|
||||||
|
console.error("[useChatContainer] Cannot send message: no session ID");
|
||||||
|
return;
|
||||||
|
}
|
||||||
setIsRegionBlockedModalOpen(false);
|
setIsRegionBlockedModalOpen(false);
|
||||||
if (isUserMessage) {
|
if (isUserMessage) {
|
||||||
const userMessage = createUserMessage(content);
|
const userMessage = createUserMessage(content);
|
||||||
@@ -437,19 +214,31 @@ export function useChatContainer({
|
|||||||
setHasTextChunks(false);
|
setHasTextChunks(false);
|
||||||
setIsStreamingInitiated(true);
|
setIsStreamingInitiated(true);
|
||||||
hasResponseRef.current = false;
|
hasResponseRef.current = false;
|
||||||
textFinalizedRef.current = false;
|
|
||||||
streamEndedRef.current = false;
|
const dispatcher = createStreamEventDispatcher({
|
||||||
|
setHasTextChunks,
|
||||||
|
setStreamingChunks,
|
||||||
|
streamingChunksRef,
|
||||||
|
hasResponseRef,
|
||||||
|
setMessages,
|
||||||
|
setIsRegionBlockedModalOpen,
|
||||||
|
sessionId,
|
||||||
|
setIsStreamingInitiated,
|
||||||
|
onOperationStarted,
|
||||||
|
});
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await sendStreamMessage(
|
await sendStreamMessage(
|
||||||
sessionId,
|
sessionId,
|
||||||
content,
|
content,
|
||||||
createDispatcher(),
|
dispatcher,
|
||||||
isUserMessage,
|
isUserMessage,
|
||||||
context,
|
context,
|
||||||
);
|
);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
console.error("[useChatContainer] Failed to send message:", err);
|
||||||
setIsStreamingInitiated(false);
|
setIsStreamingInitiated(false);
|
||||||
|
|
||||||
if (err instanceof Error && err.name === "AbortError") return;
|
if (err instanceof Error && err.name === "AbortError") return;
|
||||||
|
|
||||||
const errorMessage =
|
const errorMessage =
|
||||||
|
|||||||
@@ -111,7 +111,6 @@ export type ChatMessageData =
|
|||||||
toolName: string;
|
toolName: string;
|
||||||
toolId: string;
|
toolId: string;
|
||||||
operationId: string;
|
operationId: string;
|
||||||
taskId?: string; // For SSE reconnection
|
|
||||||
message: string;
|
message: string;
|
||||||
timestamp?: string | Date;
|
timestamp?: string | Date;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,6 +31,11 @@ export function MessageList({
|
|||||||
isStreaming,
|
isStreaming,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Keeps this for debugging purposes 💆🏽
|
||||||
|
*/
|
||||||
|
console.log(messages);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="relative flex min-h-0 flex-1 flex-col">
|
<div className="relative flex min-h-0 flex-1 flex-col">
|
||||||
{/* Top fade shadow */}
|
{/* Top fade shadow */}
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import { INITIAL_STREAM_ID } from "./chat-constants";
|
|
||||||
import type {
|
import type {
|
||||||
ActiveStream,
|
ActiveStream,
|
||||||
StreamChunk,
|
StreamChunk,
|
||||||
@@ -11,14 +10,8 @@ import {
|
|||||||
parseSSELine,
|
parseSSELine,
|
||||||
} from "./stream-utils";
|
} from "./stream-utils";
|
||||||
|
|
||||||
function notifySubscribers(
|
function notifySubscribers(stream: ActiveStream, chunk: StreamChunk) {
|
||||||
stream: ActiveStream,
|
stream.chunks.push(chunk);
|
||||||
chunk: StreamChunk,
|
|
||||||
skipStore = false,
|
|
||||||
) {
|
|
||||||
if (!skipStore) {
|
|
||||||
stream.chunks.push(chunk);
|
|
||||||
}
|
|
||||||
for (const callback of stream.onChunkCallbacks) {
|
for (const callback of stream.onChunkCallbacks) {
|
||||||
try {
|
try {
|
||||||
callback(chunk);
|
callback(chunk);
|
||||||
@@ -28,114 +21,36 @@ function notifySubscribers(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
interface StreamExecutionOptions {
|
export async function executeStream(
|
||||||
stream: ActiveStream;
|
stream: ActiveStream,
|
||||||
mode: "new" | "reconnect";
|
message: string,
|
||||||
message?: string;
|
isUserMessage: boolean,
|
||||||
isUserMessage?: boolean;
|
context?: { url: string; content: string },
|
||||||
context?: { url: string; content: string };
|
retryCount: number = 0,
|
||||||
taskId?: string;
|
|
||||||
lastMessageId?: string;
|
|
||||||
retryCount?: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
async function executeStreamInternal(
|
|
||||||
options: StreamExecutionOptions,
|
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
const {
|
|
||||||
stream,
|
|
||||||
mode,
|
|
||||||
message,
|
|
||||||
isUserMessage,
|
|
||||||
context,
|
|
||||||
taskId,
|
|
||||||
lastMessageId = INITIAL_STREAM_ID,
|
|
||||||
retryCount = 0,
|
|
||||||
} = options;
|
|
||||||
|
|
||||||
const { sessionId, abortController } = stream;
|
const { sessionId, abortController } = stream;
|
||||||
const isReconnect = mode === "reconnect";
|
|
||||||
|
|
||||||
if (isReconnect) {
|
|
||||||
if (!taskId) {
|
|
||||||
throw new Error("taskId is required for reconnect mode");
|
|
||||||
}
|
|
||||||
if (lastMessageId === null || lastMessageId === undefined) {
|
|
||||||
throw new Error("lastMessageId is required for reconnect mode");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (!message) {
|
|
||||||
throw new Error("message is required for new stream mode");
|
|
||||||
}
|
|
||||||
if (isUserMessage === undefined) {
|
|
||||||
throw new Error("isUserMessage is required for new stream mode");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
let url: string;
|
const url = `/api/chat/sessions/${sessionId}/stream`;
|
||||||
let fetchOptions: RequestInit;
|
const body = JSON.stringify({
|
||||||
|
message,
|
||||||
|
is_user_message: isUserMessage,
|
||||||
|
context: context || null,
|
||||||
|
});
|
||||||
|
|
||||||
if (isReconnect) {
|
const response = await fetch(url, {
|
||||||
url = `/api/chat/tasks/${taskId}/stream?last_message_id=${encodeURIComponent(lastMessageId)}`;
|
method: "POST",
|
||||||
fetchOptions = {
|
headers: {
|
||||||
method: "GET",
|
"Content-Type": "application/json",
|
||||||
headers: {
|
Accept: "text/event-stream",
|
||||||
Accept: "text/event-stream",
|
},
|
||||||
},
|
body,
|
||||||
signal: abortController.signal,
|
signal: abortController.signal,
|
||||||
};
|
});
|
||||||
} else {
|
|
||||||
url = `/api/chat/sessions/${sessionId}/stream`;
|
|
||||||
fetchOptions = {
|
|
||||||
method: "POST",
|
|
||||||
headers: {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
Accept: "text/event-stream",
|
|
||||||
},
|
|
||||||
body: JSON.stringify({
|
|
||||||
message,
|
|
||||||
is_user_message: isUserMessage,
|
|
||||||
context: context || null,
|
|
||||||
}),
|
|
||||||
signal: abortController.signal,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
const response = await fetch(url, fetchOptions);
|
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
const errorText = await response.text();
|
const errorText = await response.text();
|
||||||
let errorCode: string | undefined;
|
throw new Error(errorText || `HTTP ${response.status}`);
|
||||||
let errorMessage = errorText || `HTTP ${response.status}`;
|
|
||||||
try {
|
|
||||||
const parsed = JSON.parse(errorText);
|
|
||||||
if (parsed.detail) {
|
|
||||||
const detail =
|
|
||||||
typeof parsed.detail === "string"
|
|
||||||
? parsed.detail
|
|
||||||
: parsed.detail.message || JSON.stringify(parsed.detail);
|
|
||||||
errorMessage = detail;
|
|
||||||
errorCode =
|
|
||||||
typeof parsed.detail === "object" ? parsed.detail.code : undefined;
|
|
||||||
}
|
|
||||||
} catch {}
|
|
||||||
|
|
||||||
const isPermanentError =
|
|
||||||
isReconnect &&
|
|
||||||
(response.status === 404 ||
|
|
||||||
response.status === 403 ||
|
|
||||||
response.status === 410);
|
|
||||||
|
|
||||||
const error = new Error(errorMessage) as Error & {
|
|
||||||
status?: number;
|
|
||||||
isPermanent?: boolean;
|
|
||||||
taskErrorCode?: string;
|
|
||||||
};
|
|
||||||
error.status = response.status;
|
|
||||||
error.isPermanent = isPermanentError;
|
|
||||||
error.taskErrorCode = errorCode;
|
|
||||||
throw error;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!response.body) {
|
if (!response.body) {
|
||||||
@@ -189,7 +104,9 @@ async function executeStreamInternal(
|
|||||||
);
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
} catch {}
|
} catch (err) {
|
||||||
|
console.warn("[StreamExecutor] Failed to parse SSE chunk:", err);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -200,17 +117,19 @@ async function executeStreamInternal(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const isPermanentError =
|
if (retryCount < MAX_RETRIES) {
|
||||||
err instanceof Error &&
|
|
||||||
(err as Error & { isPermanent?: boolean }).isPermanent;
|
|
||||||
|
|
||||||
if (!isPermanentError && retryCount < MAX_RETRIES) {
|
|
||||||
const retryDelay = INITIAL_RETRY_DELAY * Math.pow(2, retryCount);
|
const retryDelay = INITIAL_RETRY_DELAY * Math.pow(2, retryCount);
|
||||||
|
console.log(
|
||||||
|
`[StreamExecutor] Retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`,
|
||||||
|
);
|
||||||
await new Promise((resolve) => setTimeout(resolve, retryDelay));
|
await new Promise((resolve) => setTimeout(resolve, retryDelay));
|
||||||
return executeStreamInternal({
|
return executeStream(
|
||||||
...options,
|
stream,
|
||||||
retryCount: retryCount + 1,
|
message,
|
||||||
});
|
isUserMessage,
|
||||||
|
context,
|
||||||
|
retryCount + 1,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
stream.status = "error";
|
stream.status = "error";
|
||||||
@@ -221,35 +140,3 @@ async function executeStreamInternal(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function executeStream(
|
|
||||||
stream: ActiveStream,
|
|
||||||
message: string,
|
|
||||||
isUserMessage: boolean,
|
|
||||||
context?: { url: string; content: string },
|
|
||||||
retryCount: number = 0,
|
|
||||||
): Promise<void> {
|
|
||||||
return executeStreamInternal({
|
|
||||||
stream,
|
|
||||||
mode: "new",
|
|
||||||
message,
|
|
||||||
isUserMessage,
|
|
||||||
context,
|
|
||||||
retryCount,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
export async function executeTaskReconnect(
|
|
||||||
stream: ActiveStream,
|
|
||||||
taskId: string,
|
|
||||||
lastMessageId: string = INITIAL_STREAM_ID,
|
|
||||||
retryCount: number = 0,
|
|
||||||
): Promise<void> {
|
|
||||||
return executeStreamInternal({
|
|
||||||
stream,
|
|
||||||
mode: "reconnect",
|
|
||||||
taskId,
|
|
||||||
lastMessageId,
|
|
||||||
retryCount,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ export function normalizeStreamChunk(
|
|||||||
|
|
||||||
switch (chunk.type) {
|
switch (chunk.type) {
|
||||||
case "text-delta":
|
case "text-delta":
|
||||||
// Vercel AI SDK sends "delta" for text content
|
|
||||||
return { type: "text_chunk", content: chunk.delta };
|
return { type: "text_chunk", content: chunk.delta };
|
||||||
case "text-end":
|
case "text-end":
|
||||||
return { type: "text_ended" };
|
return { type: "text_ended" };
|
||||||
@@ -64,10 +63,6 @@ export function normalizeStreamChunk(
|
|||||||
case "finish":
|
case "finish":
|
||||||
return { type: "stream_end" };
|
return { type: "stream_end" };
|
||||||
case "start":
|
case "start":
|
||||||
// Start event with optional taskId for reconnection
|
|
||||||
return chunk.taskId
|
|
||||||
? { type: "stream_start", taskId: chunk.taskId }
|
|
||||||
: null;
|
|
||||||
case "text-start":
|
case "text-start":
|
||||||
return null;
|
return null;
|
||||||
case "tool-input-start":
|
case "tool-input-start":
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { IconLaptop } from "@/components/__legacy__/ui/icons";
|
import { IconLaptop } from "@/components/__legacy__/ui/icons";
|
||||||
|
import { getHomepageRoute } from "@/lib/constants";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||||
import { ListChecksIcon } from "@phosphor-icons/react/dist/ssr";
|
import { ListChecksIcon } from "@phosphor-icons/react/dist/ssr";
|
||||||
@@ -23,11 +24,11 @@ interface Props {
|
|||||||
export function NavbarLink({ name, href }: Props) {
|
export function NavbarLink({ name, href }: Props) {
|
||||||
const pathname = usePathname();
|
const pathname = usePathname();
|
||||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||||
const expectedHomeRoute = isChatEnabled ? "/copilot" : "/library";
|
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||||
|
|
||||||
const isActive =
|
const isActive =
|
||||||
href === expectedHomeRoute
|
href === homepageRoute
|
||||||
? pathname === "/" || pathname.startsWith(expectedHomeRoute)
|
? pathname === "/" || pathname.startsWith(homepageRoute)
|
||||||
: pathname.includes(href);
|
: pathname.includes(href);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ export default function useAgentGraph(
|
|||||||
>(null);
|
>(null);
|
||||||
const [xyNodes, setXYNodes] = useState<CustomNode[]>([]);
|
const [xyNodes, setXYNodes] = useState<CustomNode[]>([]);
|
||||||
const [xyEdges, setXYEdges] = useState<CustomEdge[]>([]);
|
const [xyEdges, setXYEdges] = useState<CustomEdge[]>([]);
|
||||||
const betaBlocks = useGetFlag(Flag.BETA_BLOCKS) as string[];
|
const betaBlocks = useGetFlag(Flag.BETA_BLOCKS);
|
||||||
|
|
||||||
// Filter blocks based on beta flags
|
// Filter blocks based on beta flags
|
||||||
const availableBlocks = useMemo(() => {
|
const availableBlocks = useMemo(() => {
|
||||||
|
|||||||
@@ -11,3 +11,10 @@ export const API_KEY_HEADER_NAME = "X-API-Key";
|
|||||||
|
|
||||||
// Layout
|
// Layout
|
||||||
export const NAVBAR_HEIGHT_PX = 60;
|
export const NAVBAR_HEIGHT_PX = 60;
|
||||||
|
|
||||||
|
// Routes
|
||||||
|
export function getHomepageRoute(isChatEnabled?: boolean | null): string {
|
||||||
|
if (isChatEnabled === true) return "/copilot";
|
||||||
|
if (isChatEnabled === false) return "/library";
|
||||||
|
return "/";
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import { getHomepageRoute } from "@/lib/constants";
|
||||||
import { environment } from "@/services/environment";
|
import { environment } from "@/services/environment";
|
||||||
import { Key, storage } from "@/services/storage/local-storage";
|
import { Key, storage } from "@/services/storage/local-storage";
|
||||||
import { type CookieOptions } from "@supabase/ssr";
|
import { type CookieOptions } from "@supabase/ssr";
|
||||||
@@ -70,7 +71,7 @@ export function getRedirectPath(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (isAdminPage(path) && userRole !== "admin") {
|
if (isAdminPage(path) && userRole !== "admin") {
|
||||||
return "/";
|
return getHomepageRoute();
|
||||||
}
|
}
|
||||||
|
|
||||||
return null;
|
return null;
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import { getHomepageRoute } from "@/lib/constants";
|
||||||
import { environment } from "@/services/environment";
|
import { environment } from "@/services/environment";
|
||||||
import { createServerClient } from "@supabase/ssr";
|
import { createServerClient } from "@supabase/ssr";
|
||||||
import { NextResponse, type NextRequest } from "next/server";
|
import { NextResponse, type NextRequest } from "next/server";
|
||||||
@@ -66,7 +67,7 @@ export async function updateSession(request: NextRequest) {
|
|||||||
|
|
||||||
// 2. Check if user is authenticated but lacks admin role when accessing admin pages
|
// 2. Check if user is authenticated but lacks admin role when accessing admin pages
|
||||||
if (user && userRole !== "admin" && isAdminPage(pathname)) {
|
if (user && userRole !== "admin" && isAdminPage(pathname)) {
|
||||||
url.pathname = "/";
|
url.pathname = getHomepageRoute();
|
||||||
return NextResponse.redirect(url);
|
return NextResponse.redirect(url);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,9 @@ import {
|
|||||||
WebSocketNotification,
|
WebSocketNotification,
|
||||||
} from "@/lib/autogpt-server-api";
|
} from "@/lib/autogpt-server-api";
|
||||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||||
|
import { getHomepageRoute } from "@/lib/constants";
|
||||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||||
|
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||||
import Link from "next/link";
|
import Link from "next/link";
|
||||||
import { usePathname, useRouter } from "next/navigation";
|
import { usePathname, useRouter } from "next/navigation";
|
||||||
import {
|
import {
|
||||||
@@ -102,6 +104,8 @@ export default function OnboardingProvider({
|
|||||||
const pathname = usePathname();
|
const pathname = usePathname();
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
const { isLoggedIn } = useSupabase();
|
const { isLoggedIn } = useSupabase();
|
||||||
|
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||||
|
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||||
|
|
||||||
useOnboardingTimezoneDetection();
|
useOnboardingTimezoneDetection();
|
||||||
|
|
||||||
@@ -146,7 +150,7 @@ export default function OnboardingProvider({
|
|||||||
if (isOnOnboardingRoute) {
|
if (isOnOnboardingRoute) {
|
||||||
const enabled = await resolveResponse(getV1IsOnboardingEnabled());
|
const enabled = await resolveResponse(getV1IsOnboardingEnabled());
|
||||||
if (!enabled) {
|
if (!enabled) {
|
||||||
router.push("/");
|
router.push(homepageRoute);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -158,7 +162,7 @@ export default function OnboardingProvider({
|
|||||||
isOnOnboardingRoute &&
|
isOnOnboardingRoute &&
|
||||||
shouldRedirectFromOnboarding(onboarding.completedSteps, pathname)
|
shouldRedirectFromOnboarding(onboarding.completedSteps, pathname)
|
||||||
) {
|
) {
|
||||||
router.push("/");
|
router.push(homepageRoute);
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Failed to initialize onboarding:", error);
|
console.error("Failed to initialize onboarding:", error);
|
||||||
@@ -173,7 +177,7 @@ export default function OnboardingProvider({
|
|||||||
}
|
}
|
||||||
|
|
||||||
initializeOnboarding();
|
initializeOnboarding();
|
||||||
}, [api, isOnOnboardingRoute, router, isLoggedIn, pathname]);
|
}, [api, homepageRoute, isOnOnboardingRoute, router, isLoggedIn, pathname]);
|
||||||
|
|
||||||
const handleOnboardingNotification = useCallback(
|
const handleOnboardingNotification = useCallback(
|
||||||
(notification: WebSocketNotification) => {
|
(notification: WebSocketNotification) => {
|
||||||
|
|||||||
@@ -83,10 +83,6 @@ function getPostHogCredentials() {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
function getLaunchDarklyClientId() {
|
|
||||||
return process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
|
|
||||||
}
|
|
||||||
|
|
||||||
function isProductionBuild() {
|
function isProductionBuild() {
|
||||||
return process.env.NODE_ENV === "production";
|
return process.env.NODE_ENV === "production";
|
||||||
}
|
}
|
||||||
@@ -124,10 +120,7 @@ function isVercelPreview() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function areFeatureFlagsEnabled() {
|
function areFeatureFlagsEnabled() {
|
||||||
return (
|
return process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "enabled";
|
||||||
process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true" &&
|
|
||||||
Boolean(process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function isPostHogEnabled() {
|
function isPostHogEnabled() {
|
||||||
@@ -150,7 +143,6 @@ export const environment = {
|
|||||||
getSupabaseAnonKey,
|
getSupabaseAnonKey,
|
||||||
getPreviewStealingDev,
|
getPreviewStealingDev,
|
||||||
getPostHogCredentials,
|
getPostHogCredentials,
|
||||||
getLaunchDarklyClientId,
|
|
||||||
// Assertions
|
// Assertions
|
||||||
isServerSide,
|
isServerSide,
|
||||||
isClientSide,
|
isClientSide,
|
||||||
|
|||||||
@@ -1,59 +0,0 @@
|
|||||||
"use client";
|
|
||||||
|
|
||||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
|
||||||
import { useLDClient } from "launchdarkly-react-client-sdk";
|
|
||||||
import { useRouter } from "next/navigation";
|
|
||||||
import { ReactNode, useEffect, useState } from "react";
|
|
||||||
import { environment } from "../environment";
|
|
||||||
import { Flag, useGetFlag } from "./use-get-flag";
|
|
||||||
|
|
||||||
interface FeatureFlagRedirectProps {
|
|
||||||
flag: Flag;
|
|
||||||
whenDisabled: string;
|
|
||||||
children: ReactNode;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function FeatureFlagPage({
|
|
||||||
flag,
|
|
||||||
whenDisabled,
|
|
||||||
children,
|
|
||||||
}: FeatureFlagRedirectProps) {
|
|
||||||
const [isLoading, setIsLoading] = useState(true);
|
|
||||||
const router = useRouter();
|
|
||||||
const flagValue = useGetFlag(flag);
|
|
||||||
const ldClient = useLDClient();
|
|
||||||
const ldEnabled = environment.areFeatureFlagsEnabled();
|
|
||||||
const ldReady = Boolean(ldClient);
|
|
||||||
const flagEnabled = Boolean(flagValue);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
const initialize = async () => {
|
|
||||||
if (!ldEnabled) {
|
|
||||||
router.replace(whenDisabled);
|
|
||||||
setIsLoading(false);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for LaunchDarkly to initialize when enabled to prevent race conditions
|
|
||||||
if (ldEnabled && !ldReady) return;
|
|
||||||
|
|
||||||
try {
|
|
||||||
await ldClient?.waitForInitialization();
|
|
||||||
if (!flagEnabled) router.replace(whenDisabled);
|
|
||||||
} catch (error) {
|
|
||||||
console.error(error);
|
|
||||||
router.replace(whenDisabled);
|
|
||||||
} finally {
|
|
||||||
setIsLoading(false);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
initialize();
|
|
||||||
}, [ldReady, flagEnabled]);
|
|
||||||
|
|
||||||
return isLoading || !flagEnabled ? (
|
|
||||||
<LoadingSpinner size="large" cover />
|
|
||||||
) : (
|
|
||||||
<>{children}</>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,51 +0,0 @@
|
|||||||
"use client";
|
|
||||||
|
|
||||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
|
||||||
import { useLDClient } from "launchdarkly-react-client-sdk";
|
|
||||||
import { useRouter } from "next/navigation";
|
|
||||||
import { useEffect } from "react";
|
|
||||||
import { environment } from "../environment";
|
|
||||||
import { Flag, useGetFlag } from "./use-get-flag";
|
|
||||||
|
|
||||||
interface FeatureFlagRedirectProps {
|
|
||||||
flag: Flag;
|
|
||||||
whenEnabled: string;
|
|
||||||
whenDisabled: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function FeatureFlagRedirect({
|
|
||||||
flag,
|
|
||||||
whenEnabled,
|
|
||||||
whenDisabled,
|
|
||||||
}: FeatureFlagRedirectProps) {
|
|
||||||
const router = useRouter();
|
|
||||||
const flagValue = useGetFlag(flag);
|
|
||||||
const ldEnabled = environment.areFeatureFlagsEnabled();
|
|
||||||
const ldClient = useLDClient();
|
|
||||||
const ldReady = Boolean(ldClient);
|
|
||||||
const flagEnabled = Boolean(flagValue);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
const initialize = async () => {
|
|
||||||
if (!ldEnabled) {
|
|
||||||
router.replace(whenDisabled);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for LaunchDarkly to initialize when enabled to prevent race conditions
|
|
||||||
if (ldEnabled && !ldReady) return;
|
|
||||||
|
|
||||||
try {
|
|
||||||
await ldClient?.waitForInitialization();
|
|
||||||
router.replace(flagEnabled ? whenEnabled : whenDisabled);
|
|
||||||
} catch (error) {
|
|
||||||
console.error(error);
|
|
||||||
router.replace(whenDisabled);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
initialize();
|
|
||||||
}, [ldReady, flagEnabled]);
|
|
||||||
|
|
||||||
return <LoadingSpinner size="large" cover />;
|
|
||||||
}
|
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
|
||||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||||
import * as Sentry from "@sentry/nextjs";
|
import * as Sentry from "@sentry/nextjs";
|
||||||
import { LDProvider } from "launchdarkly-react-client-sdk";
|
import { LDProvider } from "launchdarkly-react-client-sdk";
|
||||||
@@ -8,17 +7,17 @@ import type { ReactNode } from "react";
|
|||||||
import { useMemo } from "react";
|
import { useMemo } from "react";
|
||||||
import { environment } from "../environment";
|
import { environment } from "../environment";
|
||||||
|
|
||||||
|
const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
|
||||||
|
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
|
||||||
const LAUNCHDARKLY_INIT_TIMEOUT_MS = 5000;
|
const LAUNCHDARKLY_INIT_TIMEOUT_MS = 5000;
|
||||||
|
|
||||||
export function LaunchDarklyProvider({ children }: { children: ReactNode }) {
|
export function LaunchDarklyProvider({ children }: { children: ReactNode }) {
|
||||||
const { user, isUserLoading } = useSupabase();
|
const { user, isUserLoading } = useSupabase();
|
||||||
const envEnabled = environment.areFeatureFlagsEnabled();
|
const isCloud = environment.isCloud();
|
||||||
const clientId = environment.getLaunchDarklyClientId();
|
const isLaunchDarklyConfigured = isCloud && envEnabled && clientId;
|
||||||
|
|
||||||
const context = useMemo(() => {
|
const context = useMemo(() => {
|
||||||
if (isUserLoading) return;
|
if (isUserLoading || !user) {
|
||||||
|
|
||||||
if (!user) {
|
|
||||||
return {
|
return {
|
||||||
kind: "user" as const,
|
kind: "user" as const,
|
||||||
key: "anonymous",
|
key: "anonymous",
|
||||||
@@ -37,17 +36,15 @@ export function LaunchDarklyProvider({ children }: { children: ReactNode }) {
|
|||||||
};
|
};
|
||||||
}, [user, isUserLoading]);
|
}, [user, isUserLoading]);
|
||||||
|
|
||||||
if (!envEnabled) {
|
if (!isLaunchDarklyConfigured) {
|
||||||
return <>{children}</>;
|
return <>{children}</>;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isUserLoading) {
|
|
||||||
return <LoadingSpinner size="large" cover />;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<LDProvider
|
<LDProvider
|
||||||
clientSideID={clientId ?? ""}
|
// Add this key prop. It will be 'anonymous' when logged out,
|
||||||
|
key={context.key}
|
||||||
|
clientSideID={clientId}
|
||||||
context={context}
|
context={context}
|
||||||
timeout={LAUNCHDARKLY_INIT_TIMEOUT_MS}
|
timeout={LAUNCHDARKLY_INIT_TIMEOUT_MS}
|
||||||
reactOptions={{ useCamelCaseFlagKeys: false }}
|
reactOptions={{ useCamelCaseFlagKeys: false }}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { DEFAULT_SEARCH_TERMS } from "@/app/(platform)/marketplace/components/HeroSection/helpers";
|
import { DEFAULT_SEARCH_TERMS } from "@/app/(platform)/marketplace/components/HeroSection/helpers";
|
||||||
import { environment } from "@/services/environment";
|
|
||||||
import { useFlags } from "launchdarkly-react-client-sdk";
|
import { useFlags } from "launchdarkly-react-client-sdk";
|
||||||
|
|
||||||
export enum Flag {
|
export enum Flag {
|
||||||
@@ -19,9 +18,24 @@ export enum Flag {
|
|||||||
CHAT = "chat",
|
CHAT = "chat",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type FlagValues = {
|
||||||
|
[Flag.BETA_BLOCKS]: string[];
|
||||||
|
[Flag.NEW_BLOCK_MENU]: boolean;
|
||||||
|
[Flag.NEW_AGENT_RUNS]: boolean;
|
||||||
|
[Flag.GRAPH_SEARCH]: boolean;
|
||||||
|
[Flag.ENABLE_ENHANCED_OUTPUT_HANDLING]: boolean;
|
||||||
|
[Flag.NEW_FLOW_EDITOR]: boolean;
|
||||||
|
[Flag.BUILDER_VIEW_SWITCH]: boolean;
|
||||||
|
[Flag.SHARE_EXECUTION_RESULTS]: boolean;
|
||||||
|
[Flag.AGENT_FAVORITING]: boolean;
|
||||||
|
[Flag.MARKETPLACE_SEARCH_TERMS]: string[];
|
||||||
|
[Flag.ENABLE_PLATFORM_PAYMENT]: boolean;
|
||||||
|
[Flag.CHAT]: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
const isPwMockEnabled = process.env.NEXT_PUBLIC_PW_TEST === "true";
|
const isPwMockEnabled = process.env.NEXT_PUBLIC_PW_TEST === "true";
|
||||||
|
|
||||||
const defaultFlags = {
|
const mockFlags = {
|
||||||
[Flag.BETA_BLOCKS]: [],
|
[Flag.BETA_BLOCKS]: [],
|
||||||
[Flag.NEW_BLOCK_MENU]: false,
|
[Flag.NEW_BLOCK_MENU]: false,
|
||||||
[Flag.NEW_AGENT_RUNS]: false,
|
[Flag.NEW_AGENT_RUNS]: false,
|
||||||
@@ -36,16 +50,17 @@ const defaultFlags = {
|
|||||||
[Flag.CHAT]: false,
|
[Flag.CHAT]: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
type FlagValues = typeof defaultFlags;
|
export function useGetFlag<T extends Flag>(flag: T): FlagValues[T] | null {
|
||||||
|
|
||||||
export function useGetFlag<T extends Flag>(flag: T): FlagValues[T] {
|
|
||||||
const currentFlags = useFlags<FlagValues>();
|
const currentFlags = useFlags<FlagValues>();
|
||||||
const flagValue = currentFlags[flag];
|
const flagValue = currentFlags[flag];
|
||||||
const areFlagsEnabled = environment.areFeatureFlagsEnabled();
|
|
||||||
|
|
||||||
if (!areFlagsEnabled || isPwMockEnabled) {
|
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
|
||||||
return defaultFlags[flag];
|
const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
|
||||||
|
const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
|
||||||
|
|
||||||
|
if (!isLaunchDarklyConfigured || isPwMockEnabled) {
|
||||||
|
return mockFlags[flag];
|
||||||
}
|
}
|
||||||
|
|
||||||
return flagValue ?? defaultFlags[flag];
|
return flagValue ?? mockFlags[flag];
|
||||||
}
|
}
|
||||||
|
|||||||
1
classic/frontend/.gitignore
vendored
1
classic/frontend/.gitignore
vendored
@@ -8,7 +8,6 @@
|
|||||||
.buildlog/
|
.buildlog/
|
||||||
.history
|
.history
|
||||||
.svn/
|
.svn/
|
||||||
.next/
|
|
||||||
migrate_working_dir/
|
migrate_working_dir/
|
||||||
|
|
||||||
# IntelliJ related
|
# IntelliJ related
|
||||||
|
|||||||
Reference in New Issue
Block a user